From dee4db7d1bcb28839e454cc7ea41d65725b23e24 Mon Sep 17 00:00:00 2001 From: jyyi1 Date: Thu, 9 Nov 2023 00:47:52 -0500 Subject: [PATCH 01/23] feat(transport): general TLS ClientHello fragmentation StreamDialer --- transport/tlsfrag/buffer.go | 307 +++++++++++++++++++++++++++++ transport/tlsfrag/stream_dialer.go | 84 ++++++++ transport/tlsfrag/writer.go | 164 +++++++++++++++ 3 files changed, 555 insertions(+) create mode 100644 transport/tlsfrag/buffer.go create mode 100644 transport/tlsfrag/stream_dialer.go create mode 100644 transport/tlsfrag/writer.go diff --git a/transport/tlsfrag/buffer.go b/transport/tlsfrag/buffer.go new file mode 100644 index 00000000..fd17c1d6 --- /dev/null +++ b/transport/tlsfrag/buffer.go @@ -0,0 +1,307 @@ +// Copyright 2023 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tlsfrag + +import ( + "encoding/binary" + "errors" + "fmt" + "io" +) + +// TLS record layout from [RFC 8446]: +// +// +-------------+ 0 +// | ContentType | +// +-------------+ 1 +// | Protocol | +// | Version | +// +-------------+ 3 +// | Record | +// | Length | +// +-------------+ 5 +// | Data | +// | ... | +// +-------------+ Record Length + 5 +// +// ContentType := invalid(0) | handshake(22) | application_data(23) | ... +// Protocol Version (deprecated) := 0x0301 ("TLS 1.0") | 0x0303 ("TLS 1.2" & "TLS 1.3") | 0x0302 ("TLS 1.1") +// 0 < Record Length (of handshake) ≤ 2^14 +// 0 ≤ Record Length (of application_data) ≤ 2^14 +// +// [RFC 8446]: https://datatracker.ietf.org/doc/html/rfc8446#section-5.1 +const ( + tlsRecordWithTypeSize = 1 // the minimum size that contains record type + tlsRecordWithVersionHeaderSize = 3 // the minimum size that contains protocol version + tlsRecordHeaderSize = 5 // the minimum size that contains the entire header + tlsTypeHandshake = 22 + tlsMaxRecordLen = 1 << 14 +) + +// errInvalidTLSClientHello is the error used when the data received is not a valid TLS Client Hello. +// Please use [errors.Is] to compare the returned err object with this instance. +var errInvalidTLSClientHello = errors.New("not a valid TLS Client Hello packet") + +func isTLSRecordTypeHandshake(hdr []byte) bool { + return hdr[0] == tlsTypeHandshake +} + +// isValidTLSProtocolVersion determines whether hdr[1:3] is a valid TLS version according to RFC: +// +// """ +// legacy_record_version: +// MUST be set to 0x0303 for all records generated by a TLS 1.3 implementation other than an initial ClientHello, +// where it MAY also be 0x0301 for compatibility purposes. This field is deprecated and MUST be ignored for all +// purposes. Previous versions of TLS would use other values in this field under some circumstances. +// """ +func isValidTLSProtocolVersion(hdr []byte) bool { + return hdr[1] == 0x03 && (0x01 <= hdr[2] && hdr[2] <= 0x03) +} + +func recordLen(hdr []byte) uint16 { + return binary.BigEndian.Uint16(hdr[3:]) +} + +func isValidRecordLenForHandshake(len uint16) bool { + return 0 < len && len <= tlsMaxRecordLen +} + +func putTLSClientHelloHeader(hdr []byte, recordLen uint16) { + _ = hdr[4] // bounds check to guarantee safety of writes below + hdr[0] = tlsTypeHandshake + hdr[1] = 0x03 + hdr[2] = 0x03 + binary.BigEndian.PutUint16(hdr[3:], recordLen) +} + +// clientHelloBuffer is a byte buffer used to receive and send the TLS Client Hello packet. +// This packet can be splitted into two records if needed. +type clientHelloBuffer struct { + data []byte // the buffer that hosts both header and content, len(data) should be either 5 or recordLen+10 + valid bool // indicate whether the content in data is a valid TLS Client Hello record + len int // the number of the bytes that has been read into data + recordLen int // the length of the original (unsplitted) record content (without header) + split int // the 0-based index to split the packet into [:split] and [split:] +} + +// newClientHelloBuffer creates and initializes a new buffer to receive TLS Client Hello packet. +func newClientHelloBuffer() *clientHelloBuffer { + // Allocate the 5 bytes header first, and reallocate it to contain the entire packet later + return &clientHelloBuffer{ + data: make([]byte, tlsRecordHeaderSize), + valid: true, + } +} + +// ReadFrom reads all the data from r and appends it to this buffer until a complete Client Hello packet has been +// received, or r returns EOF or error. It returns the number of bytes read. Any error except EOF encountered during +// the read is also returned. +// +// You can call ReadFrom multiple times if r doesn't provide enough data to build a complete Client Hello packet. +// Call HasFullyReceived to check whether a complete Client Hello packet has been constructed. +func (b *clientHelloBuffer) ReadFrom(r io.Reader) (n int64, err error) { + if !b.valid { + return 0, errInvalidTLSClientHello + } + + if b.len < tlsRecordHeaderSize { + m, e := b.readHeaderFrom(r) + n += int64(m) + if err = e; err == io.EOF { + return n, nil + } + if err != nil { + return + } + } + + if b.len < b.recordLen+tlsRecordHeaderSize { + m, e := b.readContentFrom(r) + n += int64(m) + if err = e; err == io.EOF { + return n, nil + } + } + return +} + +// WriteTo writes all data from this buffer to w until there's no more data or when an error occurs. +// It returns the number of bytes written. Any error encountered during the read is also returned. +// +// Note that the number of bytes written includes both the data read by ReadFrom and any additional headers. +// If you only want to know how many bytes from the last ReadFrom were written, check BytesOverlapped. +func (b *clientHelloBuffer) WriteTo(w io.Writer) (n int64, err error) { + if b.len > 0 { + m, e := w.Write(b.data[:b.len]) + n = int64(m) + if err = e; err != nil { + return + } + // all bytes should have been written, by definition of Write method in io.Writer + if m != b.len { + err = io.ErrShortWrite + } + } + return +} + +// HasFullyReceived returns whether a complete TLS Client Hello packet has been assembled. +func (b *clientHelloBuffer) HasFullyReceived() bool { + return b.valid && b.recordLen > 0 && b.len >= b.recordLen+tlsRecordHeaderSize +} + +// BytesOverlapped returns the number of bytes actually copied from the io.Reader in ReadFrom(r) +// to io.Writer in WriteTo, ignoring any extra headers added by Split. +// +// Here's an example explaining it further: +// +// _, _ := buf.ReadFrom([]byte{1,2}) // {1,2} are appended to buf +// rn, _ := buf.ReadFrom([]byte{3,4,5,6}) // rn == 3, {3,4,5} are appended to buf +// buf.Split(2) // will add some additional header bytes +// // now assume buf contains {1,2,h,h,h,h,h,3,4,5} +// wn, _ := buf.WriteTo(w) // wn == 8, {1,2,h,h,h,h,h,3} are written to w +// n := buf.BytesOverlapped(rn, wn) // n == 1, because only byte {3} comes from the last ReadFrom +func (b *clientHelloBuffer) BytesOverlapped(rn, wn int64) int { + // ndata = 12: 1 2 3 4 h h h h h 5 6 7 + // rn = 5: | | | | | + // wn = 6: | | | | | | + // overlap == 2: ^ ^ + // wn & h: x x x x | | N N N + + if wn < int64(b.split) { + // add all 5 header bytes to wn when splitted and wn doesn't overlap with h + // if no splitting, this condition will never be satifsfied because wn always >= 0 + wn += tlsRecordHeaderSize + } else if b.split > 0 && wn < int64(b.split+tlsRecordHeaderSize) { + // fill all non-overlapped h bytes to wn (bytes marked as N above) when wn partially overlaps with h + wn = int64(b.split + tlsRecordHeaderSize) + } + + // now both wn and n contain either a 5-byte header or no header at all + // the header bytes get cancelled out in the subtraction (wn - ndata) below + // rn + wn = (left+overlap) + (overlap+right) = (left+overlap+right) + overlap = ndata + overlap + if overlap := int(rn) + int(wn) - b.len; overlap >= 0 { + return overlap + } + return 0 +} + +// Content returns the Client Hello packet content (without the 5 bytes header). +// It might return an incomplete content, the caller needs to make sure HasFullyReceived before calling this function. +func (b *clientHelloBuffer) Content() []byte { + if b.len <= tlsRecordHeaderSize { + return []byte{} + } + return b.data[tlsRecordHeaderSize:b.len] +} + +// Split fragments the Client Hello packet into two TLS records at the specified 0-based splitBytes: +// [:splitBytes] and [splitBytes:]. Any necessary headers will be added to this buffer. +// +// If the packet has already be splitted before, a non-nil error and returned. +// If the split index is ≤ 0 or ≥ the total length, do nothing. +func (b *clientHelloBuffer) Split(splitBytes int) error { + if b.split > 0 { + return errors.New("packet has already been fragmented") + } + if !b.HasFullyReceived() || b.len != b.recordLen+tlsRecordHeaderSize { + return errors.New("incomplete packet cannot be fragmented") + } + if splitBytes <= 0 || splitBytes >= b.recordLen { + return nil + } + _ = b.data[b.len+tlsRecordHeaderSize-1] // bounds check to guarantee safety of writes below + + // the 2nd record starting point (including header), and move the 2nd record content 5 bytes to the right + sz2 := b.recordLen - splitBytes + b.split = tlsRecordHeaderSize + splitBytes + b.len += tlsRecordHeaderSize + + if copy(b.data[b.split+tlsRecordHeaderSize:b.len], b.data[b.split:]) != sz2 { + return errors.New("failed to split the second record") + } + + putTLSClientHelloHeader(b.data[0:], uint16(splitBytes)) + putTLSClientHelloHeader(b.data[b.split:], uint16(sz2)) + return nil +} + +// readHeaderFrom read a 5 bytes TLS Client Hello header from r into b.data[0:5]. +func (b *clientHelloBuffer) readHeaderFrom(r io.Reader) (n int, err error) { + if b.len >= tlsRecordHeaderSize { + return 0, errors.New("header has already been read") + } + if len(b.data) < tlsRecordHeaderSize { + return 0, errors.New("insufficient buffer to hold the header") + } + + prevLen := b.len + for err == nil && b.len < tlsRecordHeaderSize { + m, e := r.Read(b.data[b.len:tlsRecordHeaderSize]) + err = e + n += m + b.len += m + } + + if prevLen < tlsRecordWithTypeSize && b.len >= tlsRecordWithTypeSize { + if !isTLSRecordTypeHandshake(b.data) { + b.valid = false + err = errors.Join(err, fmt.Errorf("not a handshake record: %w", errInvalidTLSClientHello)) + } + } + + if prevLen < tlsRecordWithVersionHeaderSize && b.len >= tlsRecordWithVersionHeaderSize { + if !isValidTLSProtocolVersion(b.data) { + b.valid = false + err = errors.Join(err, fmt.Errorf("not a valid TLS version: %w", errInvalidTLSClientHello)) + } + } + + if prevLen < tlsRecordHeaderSize && b.len >= tlsRecordHeaderSize { + if rl := recordLen(b.data); !isValidRecordLenForHandshake(rl) { + b.valid = false + err = errors.Join(err, fmt.Errorf("record length out of range: %w", errInvalidTLSClientHello)) + } else { + b.recordLen = int(rl) + // allocate space for 2 headers and 1 content (might be splitted into two contents) + buf := make([]byte, b.recordLen+tlsRecordHeaderSize*2) + if copy(buf, b.data[:tlsRecordHeaderSize]) != tlsRecordHeaderSize { + err = errors.Join(err, errors.New("failed to copy header data")) + } else { + b.data = buf + } + } + } + return +} + +// readContentFrom read a recordLen bytes TLS Client Hello record content from r into b.data[5:5+recordLen]. +func (b *clientHelloBuffer) readContentFrom(r io.Reader) (n int, err error) { + fullsz := tlsRecordHeaderSize + b.recordLen + if b.len >= fullsz { + return 0, errors.New("content has already been read") + } + if len(b.data) < fullsz { + return 0, errors.New("insufficient buffer to hold the content") + } + + for err == nil && b.len < fullsz { + m, e := r.Read(b.data[b.len:fullsz]) + err = e + n += m + b.len += m + } + return +} diff --git a/transport/tlsfrag/stream_dialer.go b/transport/tlsfrag/stream_dialer.go new file mode 100644 index 00000000..298fad85 --- /dev/null +++ b/transport/tlsfrag/stream_dialer.go @@ -0,0 +1,84 @@ +// Copyright 2023 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tlsfrag + +import ( + "context" + "errors" + + "github.com/Jigsaw-Code/outline-sdk/transport" +) + +// tlsFragDialer is a [transport.StreamDialer] that uses clientHelloFragWriter to fragment the first Client Hello +// record in a TLS session. +type tlsFragDialer struct { + dialer transport.StreamDialer + frag FragFunc +} + +// Compilation guard against interface implementation +var _ transport.StreamDialer = (*tlsFragDialer)(nil) + +// FragFunc takes the content of the first [handshake record] in a TLS session as input, and returns an integer that +// represents the fragmentation point index. The input content excludes the 5-byte record header. The returned integer +// should be in range 0 to len(record)-1. The record will then be fragmented into two parts: record[:n] and record[n:]. +// If the returned index is either ≤ 0 or ≥ len(record), no fragmentation will occur. +// +// [handshake record]: https://datatracker.ietf.org/doc/html/rfc8446#section-5.1 +type FragFunc func(record []byte) int + +// NewStreamDialerFunc creates a [transport.StreamDialer] that intercepts the initial [TLS Client Hello] +// [handshake record] and splits it into two separate records before sending them. The split point is determined by the +// callback function frag. The dialer then adds appropriate headers to each record and transmits them sequentially +// using the base dialer. Following the fragmented Client Hello, all subsequent data is passed through directly without +// modification. +// +// NewStreamDialerFunc allows specifying additional options to customize its behavior. By default, if no options are +// specified, the fragmentation only affects TLS Client Hello messages targeting port 443. All other network traffic, +// including non-TLS or non-Client Hello messages, or those targeting other ports, are passed through without any +// modification. +// +// [TLS Client Hello]: https://datatracker.ietf.org/doc/html/rfc8446#section-4.1.2 +// [handshake record]: https://datatracker.ietf.org/doc/html/rfc8446#section-5.1 +func NewStreamDialerFunc(base transport.StreamDialer, frag FragFunc) (transport.StreamDialer, error) { + if base == nil { + return nil, errors.New("base dialer must not be nil") + } + if frag == nil { + return nil, errors.New("frag function must not be nil") + } + return &tlsFragDialer{base, frag}, nil +} + +// WithTLSAddrList + +// Dial implements [transport.StreamConn].Dial. It establishes a connection to raddr in the format "host-or-ip:port". +// +// If raddr matches an entry in the valid TLS address list (which can be configured using [WithTLSAddrList]), the +// initial TLS Client Hello record sent through the connection will be fragmented. +// +// If raddr is not listed in the valid TLS address list, the function simply utilizes the underlying base dialer's Dial +// function to establish the connection without any fragmentation. +func (d *tlsFragDialer) Dial(ctx context.Context, raddr string) (transport.StreamConn, error) { + innerConn, err := d.dialer.Dial(ctx, raddr) + if err != nil { + return nil, err + } + w, err := newClientHelloFragWriter(innerConn, d.frag) + if err != nil { + return nil, err + } + return transport.WrapConn(innerConn, innerConn, w), nil +} diff --git a/transport/tlsfrag/writer.go b/transport/tlsfrag/writer.go new file mode 100644 index 00000000..f00c38e7 --- /dev/null +++ b/transport/tlsfrag/writer.go @@ -0,0 +1,164 @@ +// Copyright 2023 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tlsfrag + +import ( + "bytes" + "errors" + "io" +) + +// clientHelloFragWriter intercepts the initial TLS Client Hello record and splits it into two TLS records based on the +// return value of frag function. These fragmented records are then written to the base [io.Writer]. Subsequent packets +// are not modified and are directly transmitted through the base [io.Writer]. +type clientHelloFragWriter struct { + base io.Writer + done bool + frag FragFunc + buf *clientHelloBuffer +} + +// clientHelloFragReaderFrom serves as an optimized version of clientHelloFragWriter when the base [io.Writer] also +// implements the [io.ReaderFrom] interface. +type clientHelloFragReaderFrom struct { + *clientHelloFragWriter + baseRF io.ReaderFrom +} + +// Compilation guard against interface implementation +var _ io.Writer = (*clientHelloFragWriter)(nil) +var _ io.Writer = (*clientHelloFragReaderFrom)(nil) +var _ io.ReaderFrom = (*clientHelloFragReaderFrom)(nil) + +// newClientHelloFragWriter creates a [io.Writer] that splits the first TLS Client Hello record into two records based +// on the provided frag function. It then writes these records and all subsequent messages to the base [io.Writer]. +// If the first message isn't a Client Hello, no splitting occurs and all messages are written directly to base. +// +// The returned [io.Writer] will implement the [io.ReaderFrom] interface for optimized performance if the base +// [io.Writer] implements [io.ReaderFrom]. +func newClientHelloFragWriter(base io.Writer, frag FragFunc) (io.Writer, error) { + if base == nil { + return nil, errors.New("base writer must not be nil") + } + if frag == nil { + return nil, errors.New("frag callback function must not be nil") + } + fw := &clientHelloFragWriter{ + base: base, + frag: frag, + buf: newClientHelloBuffer(), + } + if rf, ok := base.(io.ReaderFrom); ok { + return &clientHelloFragReaderFrom{fw, rf}, nil + } + return fw, nil +} + +// Write implements io.Writer.Write. It attempts to split the data received in the first one or more Write call(s) +// into two TLS records if the data corresponds to a TLS Client Hello record. +// +// Internally, this function maintains a state machine with the following states: +// - S: reading the first client hello record and appending the data to w.buf +// - F: the first client hello record has been read, fragmenting and writing to w.base +// - T: forwarding all remaining packets without modification +// +// Here is the transition graph: +// +// S ----(full handshake read)----> F -----> T +// | ^ +// | | +// +-----(invalid TLS handshake)-------------+ +func (w *clientHelloFragWriter) Write(p []byte) (written int, err error) { + // T: optimize to have fewer comparisons for the most common case. + if w.done { + return w.base.Write(p) + } + + // S + nr, e := w.buf.ReadFrom(bytes.NewBuffer(p)) + + // S -> T + if errors.Is(e, errInvalidTLSClientHello) { + goto FlushBufAndDone + } + + // S < x < F, wait for the next write + if e != nil || !w.buf.HasFullyReceived() { + return int(nr), e + } + + // F + if err = w.buf.Split(w.frag(w.buf.Content())); err != nil { + return int(nr), err + } + + // * -> T (err must be nil) +FlushBufAndDone: + w.done = true + nw, e := w.buf.WriteTo(w.base) + written += w.buf.BytesOverlapped(nr, nw) + w.buf = nil // allows the GC to recycle the memory + + // If WriteTo failed, no need to write more data + if err = e; err != nil { + return + } + + m, e := w.base.Write(p[nr:]) + written += m + err = e + return +} + +// ReadFrom implements io.ReaderFrom.ReadFrom. It attempts to split the first packet into two TLS records if the data +// corresponds to a TLS Client Hello record. And then copies the remaining data from r to the base io.Writer until EOF +// or error. +// +// If the first packet is not a valid TLS Client Hello, everything from r gets copied to the base io.Writer as is. +// +// It returns the number of bytes read. Any error except EOF encountered during the read is also returned. +// +// Internally, it uses a similar state machine to the one mentioned in w.Write. But the transition is simplier because +// we expect r containing all the data (while the first packet might be consisted of multiple Writes in Write). +func (w *clientHelloFragReaderFrom) ReadFrom(r io.Reader) (n int64, err error) { + // T + if w.done { + return w.baseRF.ReadFrom(r) + } + + // S & F + nr, e := w.buf.ReadFrom(r) + if err = e; err == nil && w.buf.HasFullyReceived() { + err = w.buf.Split(w.frag(w.buf.Content())) + } else if errors.Is(err, errInvalidTLSClientHello) { + err = nil + } + + // * -> T (err might be non-nil, but we still need to flush data to w.base) + w.done = true + nw, e := w.buf.WriteTo(w.base) + n += int64(w.buf.BytesOverlapped(nr, nw)) + w.buf = nil // allows the GC to recycle the memory + + // If WriteTo failed, no need to write more data + if err = errors.Join(err, e); e != nil { + return + } + + m, e := w.baseRF.ReadFrom(r) + n += m + err = e + return +} From 603eaf9b3b88d79c830d8f9f2e1c0d2ce688ae74 Mon Sep 17 00:00:00 2001 From: jyyi1 Date: Thu, 9 Nov 2023 17:49:46 -0500 Subject: [PATCH 02/23] add ReadFrom tests to buffer_test.go --- transport/tlsfrag/buffer.go | 2 +- transport/tlsfrag/buffer_test.go | 131 +++++++++++++++++++++++++++++++ 2 files changed, 132 insertions(+), 1 deletion(-) create mode 100644 transport/tlsfrag/buffer_test.go diff --git a/transport/tlsfrag/buffer.go b/transport/tlsfrag/buffer.go index fd17c1d6..9bd76bd5 100644 --- a/transport/tlsfrag/buffer.go +++ b/transport/tlsfrag/buffer.go @@ -82,7 +82,7 @@ func putTLSClientHelloHeader(hdr []byte, recordLen uint16) { _ = hdr[4] // bounds check to guarantee safety of writes below hdr[0] = tlsTypeHandshake hdr[1] = 0x03 - hdr[2] = 0x03 + hdr[2] = 0x01 binary.BigEndian.PutUint16(hdr[3:], recordLen) } diff --git a/transport/tlsfrag/buffer_test.go b/transport/tlsfrag/buffer_test.go new file mode 100644 index 00000000..d911c515 --- /dev/null +++ b/transport/tlsfrag/buffer_test.go @@ -0,0 +1,131 @@ +// Copyright 2023 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tlsfrag + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/require" +) + +// Example TLS Client Hello packet copied from https://tls13.xargs.org/#client-hello +// Total Len = 253, ContentLen = 253 - 5 = 248 +var exampleTLS13ClientHello = []byte{ + 0x16, 0x03, 0x01, 0x00, 0xf8, 0x01, 0x00, 0x00, 0xf4, 0x03, 0x03, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, + 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, + 0x1d, 0x1e, 0x1f, 0x20, 0xe0, 0xe1, 0xe2, 0xe3, 0xe4, 0xe5, 0xe6, 0xe7, 0xe8, 0xe9, 0xea, 0xeb, 0xec, 0xed, 0xee, 0xef, + 0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8, 0xf9, 0xfa, 0xfb, 0xfc, 0xfd, 0xfe, 0xff, 0x00, 0x08, 0x13, 0x02, + 0x13, 0x03, 0x13, 0x01, 0x00, 0xff, 0x01, 0x00, 0x00, 0xa3, 0x00, 0x00, 0x00, 0x18, 0x00, 0x16, 0x00, 0x00, 0x13, 0x65, + 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, 0x75, 0x6c, 0x66, 0x68, 0x65, 0x69, 0x6d, 0x2e, 0x6e, 0x65, 0x74, 0x00, 0x0b, + 0x00, 0x04, 0x03, 0x00, 0x01, 0x02, 0x00, 0x0a, 0x00, 0x16, 0x00, 0x14, 0x00, 0x1d, 0x00, 0x17, 0x00, 0x1e, 0x00, 0x19, + 0x00, 0x18, 0x01, 0x00, 0x01, 0x01, 0x01, 0x02, 0x01, 0x03, 0x01, 0x04, 0x00, 0x23, 0x00, 0x00, 0x00, 0x16, 0x00, 0x00, + 0x00, 0x17, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x1e, 0x00, 0x1c, 0x04, 0x03, 0x05, 0x03, 0x06, 0x03, 0x08, 0x07, 0x08, 0x08, + 0x08, 0x09, 0x08, 0x0a, 0x08, 0x0b, 0x08, 0x04, 0x08, 0x05, 0x08, 0x06, 0x04, 0x01, 0x05, 0x01, 0x06, 0x01, 0x00, 0x2b, + 0x00, 0x03, 0x02, 0x03, 0x04, 0x00, 0x2d, 0x00, 0x02, 0x01, 0x01, 0x00, 0x33, 0x00, 0x26, 0x00, 0x24, 0x00, 0x1d, 0x00, + 0x20, 0x35, 0x80, 0x72, 0xd6, 0x36, 0x58, 0x80, 0xd1, 0xae, 0xea, 0x32, 0x9a, 0xdf, 0x91, 0x21, 0x38, 0x38, 0x51, 0xed, + 0x21, 0xa2, 0x8e, 0x3b, 0x75, 0xe9, 0x65, 0xd0, 0xd2, 0xcd, 0x16, 0x62, 0x54, +} + +// Test ReadFrom a Reader containing exactly one Client Hello record. +func TestReadFromClientHelloSingleReader(t *testing.T) { + buf := newClientHelloBuffer() + r := bytes.NewBuffer(exampleTLS13ClientHello) + require.Equal(t, 253, r.Len()) + + n, err := buf.ReadFrom(r) + require.NoError(t, err) + require.Equal(t, 253, int(n)) + require.Equal(t, 0, r.Len()) + + require.True(t, buf.HasFullyReceived()) + require.Equal(t, exampleTLS13ClientHello[5:], buf.Content()) + require.Equal(t, 248, len(buf.Content())) +} + +// Test ReadFrom multiple Readers containing exactly one Client Hello record. +func TestReadFromClientHelloMultipleReaders(t *testing.T) { + buf := newClientHelloBuffer() + r1 := bytes.NewBuffer(exampleTLS13ClientHello[:2]) + r2 := bytes.NewBuffer(exampleTLS13ClientHello[2:123]) + r3 := bytes.NewBuffer(exampleTLS13ClientHello[123:]) + + require.Equal(t, 2, r1.Len()) + n, err := buf.ReadFrom(r1) + require.NoError(t, err) + require.Equal(t, 2, int(n)) + require.Equal(t, 0, r1.Len()) + require.False(t, buf.HasFullyReceived()) + + require.Equal(t, 123-2, r2.Len()) + n, err = buf.ReadFrom(r2) + require.NoError(t, err) + require.Equal(t, 123-2, int(n)) + require.Equal(t, 0, r2.Len()) + require.False(t, buf.HasFullyReceived()) + + require.Equal(t, 253-123, r3.Len()) + n, err = buf.ReadFrom(r3) + require.NoError(t, err) + require.Equal(t, 253-123, int(n)) + require.Equal(t, 0, r3.Len()) + + require.True(t, buf.HasFullyReceived()) + require.Equal(t, exampleTLS13ClientHello[5:], buf.Content()) + require.Equal(t, 248, len(buf.Content())) +} + +// Test ReadFrom a Reader containing Client Hello and 8 more extra bytes. +func TestReadFromClientHelloExtraBytesSingleReader(t *testing.T) { + buf := newClientHelloBuffer() + r := bytes.NewBuffer(append(exampleTLS13ClientHello, 0x88, 0x87, 0x86, 0x85, 0x84, 0x83, 0x82, 0x81)) + require.Equal(t, 261, r.Len()) + + n, err := buf.ReadFrom(r) + require.NoError(t, err) + require.Equal(t, 253, int(n)) + require.Equal(t, 8, r.Len()) + + require.True(t, buf.HasFullyReceived()) + require.Equal(t, exampleTLS13ClientHello[5:], buf.Content()) + require.Equal(t, 248, len(buf.Content())) + require.Equal(t, []byte{0x88, 0x87, 0x86, 0x85, 0x84, 0x83, 0x82, 0x81}, r.Bytes()) +} + +// Test ReadFrom multiple Readers containing Client Hello and 8 more extra bytes. +func TestReadFromClientHelloExtraBytesMultipleReaders(t *testing.T) { + buf := newClientHelloBuffer() + pkt := append(exampleTLS13ClientHello, 0x88, 0x87, 0x86, 0x85, 0x84, 0x83, 0x82, 0x81) + r1 := bytes.NewBuffer(pkt[:123]) + r2 := bytes.NewBuffer(pkt[123:]) + + require.Equal(t, 123, r1.Len()) + n, err := buf.ReadFrom(r1) + require.NoError(t, err) + require.Equal(t, 123, int(n)) + require.Equal(t, 0, r1.Len()) + require.False(t, buf.HasFullyReceived()) + + require.Equal(t, 261-123, r2.Len()) + n, err = buf.ReadFrom(r2) + require.NoError(t, err) + require.Equal(t, 261-123-8, int(n)) + require.Equal(t, 8, r2.Len()) + + require.True(t, buf.HasFullyReceived()) + require.Equal(t, exampleTLS13ClientHello[5:], buf.Content()) + require.Equal(t, 248, len(buf.Content())) + require.Equal(t, []byte{0x88, 0x87, 0x86, 0x85, 0x84, 0x83, 0x82, 0x81}, r2.Bytes()) +} From 779fe682f7133c01af2614ce1b1a92617c1c66b9 Mon Sep 17 00:00:00 2001 From: jyyi1 Date: Fri, 10 Nov 2023 16:28:17 -0500 Subject: [PATCH 03/23] add TLS address list option to the dialer --- transport/tlsfrag/stream_dialer.go | 119 +++++++++++++- transport/tlsfrag/stream_dialer_test.go | 202 ++++++++++++++++++++++++ 2 files changed, 313 insertions(+), 8 deletions(-) create mode 100644 transport/tlsfrag/stream_dialer_test.go diff --git a/transport/tlsfrag/stream_dialer.go b/transport/tlsfrag/stream_dialer.go index 298fad85..965a8668 100644 --- a/transport/tlsfrag/stream_dialer.go +++ b/transport/tlsfrag/stream_dialer.go @@ -17,6 +17,10 @@ package tlsfrag import ( "context" "errors" + "fmt" + "net" + "strconv" + "strings" "github.com/Jigsaw-Code/outline-sdk/transport" ) @@ -26,6 +30,7 @@ import ( type tlsFragDialer struct { dialer transport.StreamDialer frag FragFunc + config *DialerConfiguration } // Compilation guard against interface implementation @@ -39,6 +44,18 @@ var _ transport.StreamDialer = (*tlsFragDialer)(nil) // [handshake record]: https://datatracker.ietf.org/doc/html/rfc8446#section-5.1 type FragFunc func(record []byte) int +// DialerConfiguration is an internal type used to configure the [transport.StreamDialer] created by +// [NewStreamDialerFunc]. You don't need to work with it directly. Instead, use the provided configuration functions +// like [WithTLSHostPortList]. +type DialerConfiguration struct { + addrs []*tlsAddrEntry +} + +// DialerConfigurer updates the settings in the internal DialerConfiguration object. You can use the configuration +// functions such as [WithTLSHostPortList] to create configurers and then pass them to NewStreamDialerFunc to create a +// [transport.StreamDialer] with your desired configuration. +type DialerConfigurer func(*DialerConfiguration) error + // NewStreamDialerFunc creates a [transport.StreamDialer] that intercepts the initial [TLS Client Hello] // [handshake record] and splits it into two separate records before sending them. The split point is determined by the // callback function frag. The dialer then adds appropriate headers to each record and transmits them sequentially @@ -52,33 +69,119 @@ type FragFunc func(record []byte) int // // [TLS Client Hello]: https://datatracker.ietf.org/doc/html/rfc8446#section-4.1.2 // [handshake record]: https://datatracker.ietf.org/doc/html/rfc8446#section-5.1 -func NewStreamDialerFunc(base transport.StreamDialer, frag FragFunc) (transport.StreamDialer, error) { +func NewStreamDialerFunc(base transport.StreamDialer, frag FragFunc, options ...DialerConfigurer) (transport.StreamDialer, error) { if base == nil { return nil, errors.New("base dialer must not be nil") } if frag == nil { return nil, errors.New("frag function must not be nil") } - return &tlsFragDialer{base, frag}, nil + config := &DialerConfiguration{ + addrs: []*tlsAddrEntry{{"", 443}}, + } + for _, opt := range options { + if opt != nil { + if err := opt(config); err != nil { + return nil, err + } + } + } + return &tlsFragDialer{base, frag, config}, nil } -// WithTLSAddrList +// WithTLSHostPortList tells the [transport.StreamDialer] which connections to treat as TLS. Only connections matching +// entries in the tlsAddrs list will be treated as TLS traffic and fragmented accordingly. +// +// Each entry in the tlsAddrs list should be in the format "host:port", where "host" can be an IP address or a domain +// name, and "port" must be a valid port number. You can use empty string "" as the "host" to only match based on the +// port, and "0" as the "port" to match any port. +// +// The default list only includes ":443", meaning all traffic on port 443 is treated as TLS. This function overrides +// the entire list. So if you want to add entries, you need to include ":443" along with your additional entries. +// +// Matching for "host" is case-insensitive and strict. For example, "google.com:123" will only match "google.com" and +// not "www.google.com". Subdomain wildcards are not supported. +func WithTLSHostPortList(tlsAddrs []string) DialerConfigurer { + return func(c *DialerConfiguration) error { + addrs := make([]*tlsAddrEntry, 0, len(tlsAddrs)) + for _, hostport := range tlsAddrs { + addr, err := parseTLSAddrEntry(hostport) + if err != nil { + return err + } + addrs = append(addrs, addr) + } + c.addrs = addrs + return nil + } +} // Dial implements [transport.StreamConn].Dial. It establishes a connection to raddr in the format "host-or-ip:port". // -// If raddr matches an entry in the valid TLS address list (which can be configured using [WithTLSAddrList]), the +// If raddr matches an entry in the valid TLS address list (which can be configured using [WithTLSHostPortList]), the // initial TLS Client Hello record sent through the connection will be fragmented. // // If raddr is not listed in the valid TLS address list, the function simply utilizes the underlying base dialer's Dial // function to establish the connection without any fragmentation. -func (d *tlsFragDialer) Dial(ctx context.Context, raddr string) (transport.StreamConn, error) { - innerConn, err := d.dialer.Dial(ctx, raddr) +func (d *tlsFragDialer) Dial(ctx context.Context, raddr string) (conn transport.StreamConn, err error) { + conn, err = d.dialer.Dial(ctx, raddr) + if err != nil { + return + } + for _, addr := range d.config.addrs { + if addr.matches(raddr) { + w, err := newClientHelloFragWriter(conn, d.frag) + if err != nil { + return nil, err + } + return transport.WrapConn(conn, conn, w), nil + } + } + return +} + +// tlsAddrEntry reprsents an entry of the TLS traffic list. See [WithTLSHostPortList]. +type tlsAddrEntry struct { + host string + port int +} + +// parseTLSAddrEntry parses hostport in format "host:port" and returns the corresponding tlsAddrEntry. +func parseTLSAddrEntry(hostport string) (*tlsAddrEntry, error) { + host, portStr, err := net.SplitHostPort(hostport) if err != nil { return nil, err } - w, err := newClientHelloFragWriter(innerConn, d.frag) + port, err := strconv.Atoi(portStr) if err != nil { return nil, err } - return transport.WrapConn(innerConn, innerConn, w), nil + if port < 0 || port > 65535 { + return nil, fmt.Errorf("port must be within 0-65535: %w", strconv.ErrRange) + } + return &tlsAddrEntry{host, port}, nil +} + +// matches returns whether raddr matches this entry. +func (e *tlsAddrEntry) matches(raddr string) bool { + if len(e.host) == 0 && e.port == 0 { + return true + } + host, portStr, err := net.SplitHostPort(raddr) + if err != nil { + return false + } + if len(e.host) > 0 && !strings.EqualFold(e.host, host) { + return false + } + if e.port > 0 { + port, err := strconv.Atoi(portStr) + if err != nil { + return false + } + if e.port != port { + return false + } + } + return true } diff --git a/transport/tlsfrag/stream_dialer_test.go b/transport/tlsfrag/stream_dialer_test.go new file mode 100644 index 00000000..eadccbec --- /dev/null +++ b/transport/tlsfrag/stream_dialer_test.go @@ -0,0 +1,202 @@ +// Copyright 2023 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tlsfrag + +import ( + "context" + "strconv" + "testing" + + "github.com/Jigsaw-Code/outline-sdk/transport" + "github.com/stretchr/testify/require" +) + +// this is the local conn that can be shared across tests +var theLocalConn = &localConn{} + +// Make sure NewStreamDialer returns error on invalid WithTLSHostPortList calls. +func TestNewStreamDialerWithInvalidTLSAddr(t *testing.T) { + cases := []struct { + addr string + errType error // nil indicates general error + }{ + {"1.2.3.4", nil}, + {":::::", nil}, + {"1.2.3.4:654-321", strconv.ErrSyntax}, + {"1.2.3.4:--8080", strconv.ErrSyntax}, + {"[::]:10000000000", strconv.ErrRange}, + {"1.2.3.4:-1234", strconv.ErrRange}, + {":654321", strconv.ErrRange}, + } + for _, tc := range cases { + d, err := NewStreamDialerFunc(localConnDialer{}, func([]byte) int { return 0 }, WithTLSHostPortList([]string{tc.addr})) + require.Error(t, err, tc.addr) + if tc.errType != nil { + require.ErrorIs(t, err, tc.errType, tc.addr) + } + require.Nil(t, d) + } +} + +// Make sure no fragmentation connection is created if raddr is not in the allowed list. +func TestDialFragmentOnTLSAddrOnly(t *testing.T) { + tlsAddrs := []string{ + ":443", // default entry + ":990", // additional FTPS port + ":853", // additional DNS-over-TLS port + "pop.gmail.com:995", // Gmail pop3 + } + cases := []struct { + msg string + raddrs []string + shouldFrag bool + shouldFragWithList bool + }{ + { + msg: "*:443 should be fragmented, raddr = %s", + raddrs: []string{"example.com:443", "66.77.88.99:443", "[2001:db8::1]:443"}, + shouldFrag: true, + shouldFragWithList: true, + }, + { + msg: "*:990 should be fragmented by allowlist, raddr = %s", + raddrs: []string{"my-test.org:990", "192.168.1.10:990", "[2001:db8:3333:4444:5555:6666:7777:8888]:990"}, + shouldFrag: false, + shouldFragWithList: true, + }, + { + msg: "*:8080 should not be fragmented, raddr = %s", + raddrs: []string{"google.com:8080", "64.233.191.255:8080", "[2001:db8:3333:4444:5555:6666:7777:8888]:8080"}, + shouldFrag: false, + shouldFragWithList: false, + }, + { + msg: "DNS ports should not be fragmented, raddr = %s", + raddrs: []string{"8.8.8.8:53", "8.8.4.4:53", "2001:4860:4860::8888", "2001:4860:4860::8844"}, + shouldFrag: false, + shouldFragWithList: false, + }, + { + msg: "DNS over TLS ports should be fragmented by allowlist, raddr = %s", + raddrs: []string{"9.9.9.9:853", "8.8.4.4:853", "[2001:4860:4860::8844]:853", "[2620:fe::fe]:853"}, + shouldFrag: false, + shouldFragWithList: true, + }, + { + msg: "only gmail POP3 should be fragmented by allowlist, raddr = %s", + raddrs: []string{"pop.GMail.com:995"}, + shouldFrag: false, + shouldFragWithList: true, + }, + { + msg: "non-gmail POP3 should not be fragmented, raddr = %s", + raddrs: []string{"8.8.8.8:995", "outlook.office365.com:995", "outlook.office365.com:993", "pop.gmail.com:993"}, + shouldFrag: false, + shouldFragWithList: false, + }, + } + + base := localConnDialer{} + assertShouldFrag := func(conn transport.StreamConn, msg, addr string) { + prevWrCnt := theLocalConn.writeCount + // this Write should not be pushed to theLocalConn yet because it's a valid TLS handshake + conn.Write([]byte{22}) + + nonFragConn, ok := conn.(*localConn) + require.False(t, ok, msg, addr) + require.Nil(t, nonFragConn, msg) + require.Equal(t, prevWrCnt, theLocalConn.writeCount, msg, addr) + } + assertShouldNotFrag := func(conn transport.StreamConn, msg, addr string) { + prevWrCnt := theLocalConn.writeCount + // this Write should be pushed to theLocalConn because it's a direct Write call + conn.Write([]byte{22}) + + nonFragConn, ok := conn.(*localConn) + require.True(t, ok, msg, addr) + require.NotNil(t, nonFragConn, msg, addr) + require.Equal(t, theLocalConn, nonFragConn) + require.Equal(t, prevWrCnt+1, theLocalConn.writeCount, msg, addr) + } + + // default dialer + d1, err := NewStreamDialerFunc(base, func([]byte) int { return 0 }) + require.NoError(t, err) + require.NotNil(t, d1) + + // with additional tls addrs + d2, err := NewStreamDialerFunc(base, func([]byte) int { return 0 }, WithTLSHostPortList(tlsAddrs)) + require.NoError(t, err) + require.NotNil(t, d2) + + // with no tls addrs + d3, err := NewStreamDialerFunc(base, func([]byte) int { return 0 }, WithTLSHostPortList([]string{})) + require.NoError(t, err) + require.NotNil(t, d3) + + // all traffic + d4, err := NewStreamDialerFunc(base, func([]byte) int { return 0 }, WithTLSHostPortList([]string{":0"})) + require.NoError(t, err) + require.NotNil(t, d4) + + for _, tc := range cases { + for _, addr := range tc.raddrs { + conn, err := d1.Dial(context.Background(), addr) + require.NoError(t, err, tc.msg, addr) + require.NotNil(t, conn, tc.msg, addr) + if tc.shouldFrag { + assertShouldFrag(conn, tc.msg, addr) + } else { + assertShouldNotFrag(conn, tc.msg, addr) + } + + conn, err = d2.Dial(context.Background(), addr) + require.NoError(t, err, tc.msg, addr) + require.NotNil(t, conn, tc.msg, addr) + if tc.shouldFragWithList { + assertShouldFrag(conn, tc.msg, addr) + } else { + assertShouldNotFrag(conn, tc.msg, addr) + } + + conn, err = d3.Dial(context.Background(), addr) + require.NoError(t, err, tc.msg, addr) + require.NotNil(t, conn, tc.msg, addr) + assertShouldNotFrag(conn, tc.msg, addr) + + conn, err = d4.Dial(context.Background(), addr) + require.NoError(t, err, tc.msg, addr) + require.NotNil(t, conn, tc.msg, addr) + assertShouldFrag(conn, tc.msg, addr) + } + } +} + +// testing utilitites + +type localConnDialer struct{} +type localConn struct { + transport.StreamConn + writeCount int +} + +func (localConnDialer) Dial(ctx context.Context, raddr string) (transport.StreamConn, error) { + return theLocalConn, nil +} + +func (lc *localConn) Write(b []byte) (n int, err error) { + lc.writeCount++ + return len(b), nil +} From 1ff2fedbdf436f9cd339ec63a128d03264afc598 Mon Sep 17 00:00:00 2001 From: jyyi1 Date: Mon, 13 Nov 2023 15:32:58 -0500 Subject: [PATCH 04/23] add wrapconn --- transport/tlsfrag/stream_dialer.go | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/transport/tlsfrag/stream_dialer.go b/transport/tlsfrag/stream_dialer.go index 965a8668..38a5991e 100644 --- a/transport/tlsfrag/stream_dialer.go +++ b/transport/tlsfrag/stream_dialer.go @@ -130,16 +130,23 @@ func (d *tlsFragDialer) Dial(ctx context.Context, raddr string) (conn transport. } for _, addr := range d.config.addrs { if addr.matches(raddr) { - w, err := newClientHelloFragWriter(conn, d.frag) - if err != nil { - return nil, err - } - return transport.WrapConn(conn, conn, w), nil + return WrapConnFunc(conn, d.frag) } } return } +// WrapConnFunc wraps the base [transport.StreamConn] and splits the first TLS Client Hello packet into two records +// according to the frag function. Subsequent data is forwarded without modification. If the first packet isn't a valid +// Client Hello, WrapConnFunc simply forwards all data through transparently. +func WrapConnFunc(base transport.StreamConn, frag FragFunc) (transport.StreamConn, error) { + w, err := newClientHelloFragWriter(base, frag) + if err != nil { + return nil, err + } + return transport.WrapConn(base, base, w), nil +} + // tlsAddrEntry reprsents an entry of the TLS traffic list. See [WithTLSHostPortList]. type tlsAddrEntry struct { host string From 725c3fe1f40f26080c927e15724710b90a7dad47 Mon Sep 17 00:00:00 2001 From: jyyi1 Date: Wed, 15 Nov 2023 14:18:30 -0500 Subject: [PATCH 05/23] 1 round of refactoring --- transport/tlsfrag/buffer.go | 337 +++++++++--------------------------- transport/tlsfrag/tls.go | 116 +++++++++++++ transport/tlsfrag/writer.go | 135 ++++++++------- 3 files changed, 268 insertions(+), 320 deletions(-) create mode 100644 transport/tlsfrag/tls.go diff --git a/transport/tlsfrag/buffer.go b/transport/tlsfrag/buffer.go index 9bd76bd5..4279c590 100644 --- a/transport/tlsfrag/buffer.go +++ b/transport/tlsfrag/buffer.go @@ -15,293 +15,120 @@ package tlsfrag import ( - "encoding/binary" "errors" "fmt" "io" ) -// TLS record layout from [RFC 8446]: -// -// +-------------+ 0 -// | ContentType | -// +-------------+ 1 -// | Protocol | -// | Version | -// +-------------+ 3 -// | Record | -// | Length | -// +-------------+ 5 -// | Data | -// | ... | -// +-------------+ Record Length + 5 -// -// ContentType := invalid(0) | handshake(22) | application_data(23) | ... -// Protocol Version (deprecated) := 0x0301 ("TLS 1.0") | 0x0303 ("TLS 1.2" & "TLS 1.3") | 0x0302 ("TLS 1.1") -// 0 < Record Length (of handshake) ≤ 2^14 -// 0 ≤ Record Length (of application_data) ≤ 2^14 -// -// [RFC 8446]: https://datatracker.ietf.org/doc/html/rfc8446#section-5.1 -const ( - tlsRecordWithTypeSize = 1 // the minimum size that contains record type - tlsRecordWithVersionHeaderSize = 3 // the minimum size that contains protocol version - tlsRecordHeaderSize = 5 // the minimum size that contains the entire header - tlsTypeHandshake = 22 - tlsMaxRecordLen = 1 << 14 -) - -// errInvalidTLSClientHello is the error used when the data received is not a valid TLS Client Hello. -// Please use [errors.Is] to compare the returned err object with this instance. -var errInvalidTLSClientHello = errors.New("not a valid TLS Client Hello packet") +var ( + // errTLSClientHelloFullyReceived is returned when a full TLS Client Hello has been received and no + // more data can be pushed to the buffer. + errTLSClientHelloFullyReceived = errors.New("already received a complete TLS Client Hello packet") -func isTLSRecordTypeHandshake(hdr []byte) bool { - return hdr[0] == tlsTypeHandshake -} + // errInvalidTLSClientHello is the error used when the data received is not a valid TLS Client Hello. + // Please use [errors.Is] to compare the returned err object with this instance. + errInvalidTLSClientHello = errors.New("not a valid TLS Client Hello packet") +) -// isValidTLSProtocolVersion determines whether hdr[1:3] is a valid TLS version according to RFC: -// -// """ -// legacy_record_version: -// MUST be set to 0x0303 for all records generated by a TLS 1.3 implementation other than an initial ClientHello, -// where it MAY also be 0x0301 for compatibility purposes. This field is deprecated and MUST be ignored for all -// purposes. Previous versions of TLS would use other values in this field under some circumstances. -// """ -func isValidTLSProtocolVersion(hdr []byte) bool { - return hdr[1] == 0x03 && (0x01 <= hdr[2] && hdr[2] <= 0x03) +// clientHelloBuffer is a byte buffer used to receive and buffer a TLS Client Hello packet. +type clientHelloBuffer struct { + data []byte // the buffer that hosts both header and content, cap: 5 -> 5+len(content) + len int // the number of bytes that have already been read into data + valid bool // indicate whether the content in data is a valid TLS Client Hello record + toRead int // the number of bytes to read next, e.g. 1 -> 2 -> 2 -> len(content) } -func recordLen(hdr []byte) uint16 { - return binary.BigEndian.Uint16(hdr[3:]) -} +var _ io.Writer = (*clientHelloBuffer)(nil) +var _ io.ReaderFrom = (*clientHelloBuffer)(nil) -func isValidRecordLenForHandshake(len uint16) bool { - return 0 < len && len <= tlsMaxRecordLen +// newClientHelloBuffer creates and initializes a new buffer to receive a TLS Client Hello packet. +func newClientHelloBuffer() *clientHelloBuffer { + // Allocate the 5 bytes header first, and then reallocate it to contain the entire packet later + return &clientHelloBuffer{ + data: make([]byte, recordHeaderLen), + len: 0, + valid: true, + toRead: tlsRecordWithTypeSize, + } } -func putTLSClientHelloHeader(hdr []byte, recordLen uint16) { - _ = hdr[4] // bounds check to guarantee safety of writes below - hdr[0] = tlsTypeHandshake - hdr[1] = 0x03 - hdr[2] = 0x01 - binary.BigEndian.PutUint16(hdr[3:], recordLen) +// Len returns the length of this buffer including both the 5 bytes header and the content. +func (b *clientHelloBuffer) Len() int { + return b.len } -// clientHelloBuffer is a byte buffer used to receive and send the TLS Client Hello packet. -// This packet can be splitted into two records if needed. -type clientHelloBuffer struct { - data []byte // the buffer that hosts both header and content, len(data) should be either 5 or recordLen+10 - valid bool // indicate whether the content in data is a valid TLS Client Hello record - len int // the number of the bytes that has been read into data - recordLen int // the length of the original (unsplitted) record content (without header) - split int // the 0-based index to split the packet into [:split] and [split:] +// Bytes returns the full Client Hello packet including both the 5 bytes header and the content. +func (b *clientHelloBuffer) Bytes() []byte { + return b.data[:b.len] } -// newClientHelloBuffer creates and initializes a new buffer to receive TLS Client Hello packet. -func newClientHelloBuffer() *clientHelloBuffer { - // Allocate the 5 bytes header first, and reallocate it to contain the entire packet later - return &clientHelloBuffer{ - data: make([]byte, tlsRecordHeaderSize), - valid: true, - } +func (b *clientHelloBuffer) growBy(size int) { + buf := make([]byte, b.len+size) + copy(buf, b.data[:b.len]) + b.data = buf } -// ReadFrom reads all the data from r and appends it to this buffer until a complete Client Hello packet has been -// received, or r returns EOF or error. It returns the number of bytes read. Any error except EOF encountered during -// the read is also returned. -// -// You can call ReadFrom multiple times if r doesn't provide enough data to build a complete Client Hello packet. -// Call HasFullyReceived to check whether a complete Client Hello packet has been constructed. -func (b *clientHelloBuffer) ReadFrom(r io.Reader) (n int64, err error) { +// Write appends p to the buffer and returns the number of bytes actually used. +// If this data completes a valid TLS Client Hello, it returns errTLSClientHelloFullyReceived. +// If an invalid TLS Client Hello message is detected, it returns the error errInvalidTLSClientHello. +// If all bytes in p have been used and the buffer still requires more data to build a complete TLS Client Hello +// message, it returns (len(p), nil). +func (b *clientHelloBuffer) Write(p []byte) (n int, err error) { if !b.valid { return 0, errInvalidTLSClientHello } - if b.len < tlsRecordHeaderSize { - m, e := b.readHeaderFrom(r) - n += int64(m) - if err = e; err == io.EOF { - return n, nil - } - if err != nil { - return - } - } - - if b.len < b.recordLen+tlsRecordHeaderSize { - m, e := b.readContentFrom(r) - n += int64(m) - if err = e; err == io.EOF { - return n, nil - } - } - return -} + defer func() { b.valid = !errors.Is(err, errInvalidTLSClientHello) }() -// WriteTo writes all data from this buffer to w until there's no more data or when an error occurs. -// It returns the number of bytes written. Any error encountered during the read is also returned. -// -// Note that the number of bytes written includes both the data read by ReadFrom and any additional headers. -// If you only want to know how many bytes from the last ReadFrom were written, check BytesOverlapped. -func (b *clientHelloBuffer) WriteTo(w io.Writer) (n int64, err error) { - if b.len > 0 { - m, e := w.Write(b.data[:b.len]) - n = int64(m) - if err = e; err != nil { + for b.toRead > 0 { + if len(p) == 0 { return } - // all bytes should have been written, by definition of Write method in io.Writer - if m != b.len { - err = io.ErrShortWrite - } - } - return -} - -// HasFullyReceived returns whether a complete TLS Client Hello packet has been assembled. -func (b *clientHelloBuffer) HasFullyReceived() bool { - return b.valid && b.recordLen > 0 && b.len >= b.recordLen+tlsRecordHeaderSize -} - -// BytesOverlapped returns the number of bytes actually copied from the io.Reader in ReadFrom(r) -// to io.Writer in WriteTo, ignoring any extra headers added by Split. -// -// Here's an example explaining it further: -// -// _, _ := buf.ReadFrom([]byte{1,2}) // {1,2} are appended to buf -// rn, _ := buf.ReadFrom([]byte{3,4,5,6}) // rn == 3, {3,4,5} are appended to buf -// buf.Split(2) // will add some additional header bytes -// // now assume buf contains {1,2,h,h,h,h,h,3,4,5} -// wn, _ := buf.WriteTo(w) // wn == 8, {1,2,h,h,h,h,h,3} are written to w -// n := buf.BytesOverlapped(rn, wn) // n == 1, because only byte {3} comes from the last ReadFrom -func (b *clientHelloBuffer) BytesOverlapped(rn, wn int64) int { - // ndata = 12: 1 2 3 4 h h h h h 5 6 7 - // rn = 5: | | | | | - // wn = 6: | | | | | | - // overlap == 2: ^ ^ - // wn & h: x x x x | | N N N - if wn < int64(b.split) { - // add all 5 header bytes to wn when splitted and wn doesn't overlap with h - // if no splitting, this condition will never be satifsfied because wn always >= 0 - wn += tlsRecordHeaderSize - } else if b.split > 0 && wn < int64(b.split+tlsRecordHeaderSize) { - // fill all non-overlapped h bytes to wn (bytes marked as N above) when wn partially overlaps with h - wn = int64(b.split + tlsRecordHeaderSize) - } - - // now both wn and n contain either a 5-byte header or no header at all - // the header bytes get cancelled out in the subtraction (wn - ndata) below - // rn + wn = (left+overlap) + (overlap+right) = (left+overlap+right) + overlap = ndata + overlap - if overlap := int(rn) + int(wn) - b.len; overlap >= 0 { - return overlap - } - return 0 -} - -// Content returns the Client Hello packet content (without the 5 bytes header). -// It might return an incomplete content, the caller needs to make sure HasFullyReceived before calling this function. -func (b *clientHelloBuffer) Content() []byte { - if b.len <= tlsRecordHeaderSize { - return []byte{} - } - return b.data[tlsRecordHeaderSize:b.len] -} - -// Split fragments the Client Hello packet into two TLS records at the specified 0-based splitBytes: -// [:splitBytes] and [splitBytes:]. Any necessary headers will be added to this buffer. -// -// If the packet has already be splitted before, a non-nil error and returned. -// If the split index is ≤ 0 or ≥ the total length, do nothing. -func (b *clientHelloBuffer) Split(splitBytes int) error { - if b.split > 0 { - return errors.New("packet has already been fragmented") - } - if !b.HasFullyReceived() || b.len != b.recordLen+tlsRecordHeaderSize { - return errors.New("incomplete packet cannot be fragmented") - } - if splitBytes <= 0 || splitBytes >= b.recordLen { - return nil - } - _ = b.data[b.len+tlsRecordHeaderSize-1] // bounds check to guarantee safety of writes below - - // the 2nd record starting point (including header), and move the 2nd record content 5 bytes to the right - sz2 := b.recordLen - splitBytes - b.split = tlsRecordHeaderSize + splitBytes - b.len += tlsRecordHeaderSize - - if copy(b.data[b.split+tlsRecordHeaderSize:b.len], b.data[b.split:]) != sz2 { - return errors.New("failed to split the second record") - } - - putTLSClientHelloHeader(b.data[0:], uint16(splitBytes)) - putTLSClientHelloHeader(b.data[b.split:], uint16(sz2)) - return nil -} - -// readHeaderFrom read a 5 bytes TLS Client Hello header from r into b.data[0:5]. -func (b *clientHelloBuffer) readHeaderFrom(r io.Reader) (n int, err error) { - if b.len >= tlsRecordHeaderSize { - return 0, errors.New("header has already been read") - } - if len(b.data) < tlsRecordHeaderSize { - return 0, errors.New("insufficient buffer to hold the header") - } - - prevLen := b.len - for err == nil && b.len < tlsRecordHeaderSize { - m, e := r.Read(b.data[b.len:tlsRecordHeaderSize]) - err = e - n += m - b.len += m - } - - if prevLen < tlsRecordWithTypeSize && b.len >= tlsRecordWithTypeSize { - if !isTLSRecordTypeHandshake(b.data) { - b.valid = false - err = errors.Join(err, fmt.Errorf("not a handshake record: %w", errInvalidTLSClientHello)) + sz := b.toRead + if len(p) < sz { + sz = len(p) } - } + copy(b.data[b.len:], p[:sz]) + n += sz + b.len += sz + b.toRead -= sz + p = p[sz:] + + // check whether message is valid according to the bytes just read + switch b.len { + case tlsRecordWithTypeSize: // 1 + if typ := getRecordType(b.data); typ != recordTypeHandshake { + return n, fmt.Errorf("record type %d is not handshake: %w", typ, errInvalidTLSClientHello) + } + b.toRead = tlsRecordWithVersionHeaderSize - tlsRecordWithTypeSize // +2 - if prevLen < tlsRecordWithVersionHeaderSize && b.len >= tlsRecordWithVersionHeaderSize { - if !isValidTLSProtocolVersion(b.data) { - b.valid = false - err = errors.Join(err, fmt.Errorf("not a valid TLS version: %w", errInvalidTLSClientHello)) - } - } + case tlsRecordWithVersionHeaderSize: // 3 + if ver := getTLSVersion(b.data); !isValidTLSVersion(ver) { + return n, fmt.Errorf("%#04x is not a valid TLS version: %w", ver, errInvalidTLSClientHello) + } + b.toRead = recordHeaderLen - tlsRecordWithVersionHeaderSize // +2 - if prevLen < tlsRecordHeaderSize && b.len >= tlsRecordHeaderSize { - if rl := recordLen(b.data); !isValidRecordLenForHandshake(rl) { - b.valid = false - err = errors.Join(err, fmt.Errorf("record length out of range: %w", errInvalidTLSClientHello)) - } else { - b.recordLen = int(rl) - // allocate space for 2 headers and 1 content (might be splitted into two contents) - buf := make([]byte, b.recordLen+tlsRecordHeaderSize*2) - if copy(buf, b.data[:tlsRecordHeaderSize]) != tlsRecordHeaderSize { - err = errors.Join(err, errors.New("failed to copy header data")) - } else { - b.data = buf + case recordHeaderLen: // 5 + if b.toRead = int(getMsgLen(b.data)); !isValidMsgLenForHandshake(uint16(b.toRead)) { + return n, fmt.Errorf("message length %v out of range: %w", b.toRead, errInvalidTLSClientHello) } + b.growBy(b.toRead) } } - return -} -// readContentFrom read a recordLen bytes TLS Client Hello record content from r into b.data[5:5+recordLen]. -func (b *clientHelloBuffer) readContentFrom(r io.Reader) (n int, err error) { - fullsz := tlsRecordHeaderSize + b.recordLen - if b.len >= fullsz { - return 0, errors.New("content has already been read") - } - if len(b.data) < fullsz { - return 0, errors.New("insufficient buffer to hold the content") - } + return n, errTLSClientHelloFullyReceived +} - for err == nil && b.len < fullsz { - m, e := r.Read(b.data[b.len:fullsz]) - err = e - n += m - b.len += m - } - return +// ReadFrom reads all the data from r and appends it to this buffer until a complete Client Hello packet has been +// received, or r returns EOF or error. It returns the number of bytes read. Any error except EOF encountered during +// the read is also returned. +// +// If this buffer completes a valid TLS Client Hello, it returns errTLSClientHelloFullyReceived. +// If an invalid TLS Client Hello message is detected, it returns the error errInvalidTLSClientHello. +// If this buffer still requires more data to build a complete TLS Client Hello message, it returns nil error. +// +// You can call ReadFrom multiple times if r doesn't provide enough data to build a complete Client Hello packet. +func (b *clientHelloBuffer) ReadFrom(r io.Reader) (n int64, err error) { + return 0, errors.New("to be implemented") } diff --git a/transport/tlsfrag/tls.go b/transport/tlsfrag/tls.go new file mode 100644 index 00000000..df915c89 --- /dev/null +++ b/transport/tlsfrag/tls.go @@ -0,0 +1,116 @@ +// Copyright 2023 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tlsfrag + +import ( + "encoding/binary" +) + +// This file contains helper functions and constants for TLS Client Hello message. + +type recordType uint8 +type tlsVersion uint16 + +// TLS record layout from [RFC 8446]: +// +// +-------------+ 0 +// | RecordType | +// +-------------+ 1 +// | Protocol | +// | Version | +// +-------------+ 3 +// | Record | +// | Length | +// +-------------+ 5 +// | Message | +// | Data | +// | ... | +// +-------------+ Message Length + 5 +// +// RecordType := invalid(0) | handshake(22) | application_data(23) | ... +// Protocol Version (deprecated) := 0x0301 ("TLS 1.0") | 0x0303 ("TLS 1.2" & "TLS 1.3") | 0x0302 ("TLS 1.1") +// 0 < Message Length (of handshake) ≤ 2^14 +// 0 ≤ Message Length (of application_data) ≤ 2^14 +// +// [RFC 8446]: https://datatracker.ietf.org/doc/html/rfc8446#section-5.1 +const ( + tlsRecordWithTypeSize = 1 // the minimum size that contains record type + tlsRecordWithVersionHeaderSize = 3 // the minimum size that contains protocol version + + recordHeaderLen = 5 + maxMsgLen = 1 << 14 + + recordTypeHandshake recordType = 22 + + versionTLS10 tlsVersion = 0x0301 + versionTLS11 tlsVersion = 0x0302 + versionTLS12 tlsVersion = 0x0303 + versionTLS13 tlsVersion = 0x0304 +) + +// getRecordType gets the TLS record type from the TLS header hdr[0]. This function will panic if len(hdr) < 1. +func getRecordType(hdr []byte) recordType { + return recordType(hdr[0]) +} + +// putRecordType puts the TLS record type to the TLS header hdr[0]. This function will panic if len(hdr) < 1. +func putRecordType(hdr []byte, typ recordType) { + hdr[0] = byte(typ) +} + +// getTLSVersion gets the TLS version from the TLS header hdr[1:3]. This function will panic if len(hdr) < 3. +func getTLSVersion(hdr []byte) tlsVersion { + return tlsVersion(binary.BigEndian.Uint16(hdr[1:])) +} + +// putTLSVersion puts the TLS version to the TLS header hdr[1:3]. This function will panic if len(hdr) < 3. +func putTLSVersion(hdr []byte, ver tlsVersion) { + binary.BigEndian.PutUint16(hdr[1:], uint16(ver)) +} + +// getMsgLen gets the TLS message length from the TLS header hdr[3:5]. This function will panic if len(hdr) < 5. +func getMsgLen(hdr []byte) uint16 { + return binary.BigEndian.Uint16(hdr[3:]) +} + +// putMsgLen puts the TLS message length to the TLS header hdr[3:5]. This function will panic if len(hdr) < 5. +func putMsgLen(hdr []byte, len uint16) { + binary.BigEndian.PutUint16(hdr[3:], len) +} + +// isValidTLSProtocolVersion determines whether ver is a valid TLS version according to RFC: +// +// """ +// legacy_record_version: +// MUST be set to 0x0303 for all records generated by a TLS 1.3 implementation other than an initial ClientHello, +// where it MAY also be 0x0301 for compatibility purposes. This field is deprecated and MUST be ignored for all +// purposes. Previous versions of TLS would use other values in this field under some circumstances. +// """ +func isValidTLSVersion(ver tlsVersion) bool { + return ver == versionTLS10 || ver == versionTLS11 || ver == versionTLS12 || ver == versionTLS13 +} + +// isValidRecordLenForHandshake checks whether 0 < len ≤ 2^14. +func isValidMsgLenForHandshake(len uint16) bool { + return 0 < len && len <= maxMsgLen +} + +// This function will panic if len(hdr) < 5. +func putTLSClientHelloHeader(hdr []byte, recordLen uint16) { + _ = hdr[recordHeaderLen-1] // bounds check to guarantee safety of writes below + putRecordType(hdr, recordTypeHandshake) + putTLSVersion(hdr, versionTLS10) + putMsgLen(hdr, recordLen) +} diff --git a/transport/tlsfrag/writer.go b/transport/tlsfrag/writer.go index f00c38e7..229fe840 100644 --- a/transport/tlsfrag/writer.go +++ b/transport/tlsfrag/writer.go @@ -25,9 +25,11 @@ import ( // are not modified and are directly transmitted through the base [io.Writer]. type clientHelloFragWriter struct { base io.Writer - done bool + done bool // indicates all splitted rcds have been already written to base frag FragFunc - buf *clientHelloBuffer + + buf *clientHelloBuffer // the buffer containing and parsing a TLS Client Hello record + rcds *bytes.Buffer // the buffer containing splitted records what will be written to base } // clientHelloFragReaderFrom serves as an optimized version of clientHelloFragWriter when the base [io.Writer] also @@ -68,57 +70,31 @@ func newClientHelloFragWriter(base io.Writer, frag FragFunc) (io.Writer, error) // Write implements io.Writer.Write. It attempts to split the data received in the first one or more Write call(s) // into two TLS records if the data corresponds to a TLS Client Hello record. -// -// Internally, this function maintains a state machine with the following states: -// - S: reading the first client hello record and appending the data to w.buf -// - F: the first client hello record has been read, fragmenting and writing to w.base -// - T: forwarding all remaining packets without modification -// -// Here is the transition graph: -// -// S ----(full handshake read)----> F -----> T -// | ^ -// | | -// +-----(invalid TLS handshake)-------------+ -func (w *clientHelloFragWriter) Write(p []byte) (written int, err error) { - // T: optimize to have fewer comparisons for the most common case. +func (w *clientHelloFragWriter) Write(p []byte) (n int, err error) { if w.done { return w.base.Write(p) } - - // S - nr, e := w.buf.ReadFrom(bytes.NewBuffer(p)) - - // S -> T - if errors.Is(e, errInvalidTLSClientHello) { - goto FlushBufAndDone - } - - // S < x < F, wait for the next write - if e != nil || !w.buf.HasFullyReceived() { - return int(nr), e + if w.rcds != nil { + if _, err = w.flushRecords(); err != nil { + return + } + return w.base.Write(p) } - // F - if err = w.buf.Split(w.frag(w.buf.Content())); err != nil { - return int(nr), err + if n, err = w.buf.Write(p); err != nil { + if errors.Is(err, errTLSClientHelloFullyReceived) { + w.splitBufToRecords() + } else { + w.copyBufToRecords() + } + // recursively flush w.rcds and write the remaining content + m, e := w.Write(p[n:]) + return n + m, e } - // * -> T (err must be nil) -FlushBufAndDone: - w.done = true - nw, e := w.buf.WriteTo(w.base) - written += w.buf.BytesOverlapped(nr, nw) - w.buf = nil // allows the GC to recycle the memory - - // If WriteTo failed, no need to write more data - if err = e; err != nil { - return + if n < len(p) { + return n, io.ErrShortWrite } - - m, e := w.base.Write(p[nr:]) - written += m - err = e return } @@ -129,36 +105,65 @@ FlushBufAndDone: // If the first packet is not a valid TLS Client Hello, everything from r gets copied to the base io.Writer as is. // // It returns the number of bytes read. Any error except EOF encountered during the read is also returned. -// -// Internally, it uses a similar state machine to the one mentioned in w.Write. But the transition is simplier because -// we expect r containing all the data (while the first packet might be consisted of multiple Writes in Write). func (w *clientHelloFragReaderFrom) ReadFrom(r io.Reader) (n int64, err error) { - // T if w.done { return w.baseRF.ReadFrom(r) } + if w.rcds != nil { + if _, err = w.flushRecords(); err != nil { + return + } + return w.baseRF.ReadFrom(r) + } - // S & F - nr, e := w.buf.ReadFrom(r) - if err = e; err == nil && w.buf.HasFullyReceived() { - err = w.buf.Split(w.frag(w.buf.Content())) - } else if errors.Is(err, errInvalidTLSClientHello) { - err = nil + if n, err = w.buf.ReadFrom(r); err != nil { + if errors.Is(err, errTLSClientHelloFullyReceived) { + w.splitBufToRecords() + } else { + w.copyBufToRecords() + } + // recursively flush w.rcds and read the remaining content from r + m, e := w.ReadFrom(r) + return n + m, e } + return +} - // * -> T (err might be non-nil, but we still need to flush data to w.base) - w.done = true - nw, e := w.buf.WriteTo(w.base) - n += int64(w.buf.BytesOverlapped(nr, nw)) +// copyBuf copies w.buf into w.rcds. +func (w *clientHelloFragWriter) copyBufToRecords() { + w.rcds = bytes.NewBuffer(w.buf.Bytes()) w.buf = nil // allows the GC to recycle the memory +} - // If WriteTo failed, no need to write more data - if err = errors.Join(err, e); e != nil { +// splitBuf splits w.buf into two records and put them into w.rcds. +func (w *clientHelloFragWriter) splitBufToRecords() { + content := w.buf.Bytes()[recordHeaderLen:] + split := w.frag(content) + if split <= 0 || split >= len(content) { + w.copyBufToRecords() return } - m, e := w.baseRF.ReadFrom(r) - n += m - err = e - return + header := make([]byte, recordHeaderLen) + w.rcds = bytes.NewBuffer(make([]byte, 0, w.buf.Len()+recordHeaderLen)) + + putTLSClientHelloHeader(header, uint16(split)) + w.rcds.Write(header) + w.rcds.Write(content[:split]) + + putTLSClientHelloHeader(header, uint16(len(content)-split)) + w.rcds.Write(header) + w.rcds.Write(content[split:]) + + w.buf = nil // allows the GC to recycle the memory +} + +// flushRecords writes all bytes from w.rcds to base. +func (w *clientHelloFragWriter) flushRecords() (int, error) { + n, err := io.Copy(w.base, w.rcds) + if w.rcds.Len() == 0 { + w.rcds = nil // allows the GC to recycle the memory + w.done = true + } + return int(n), err } From 653b8558520f730c53e843500f778d9d27431f18 Mon Sep 17 00:00:00 2001 From: jyyi1 Date: Thu, 16 Nov 2023 11:46:41 -0500 Subject: [PATCH 06/23] fix test cases for buffer --- transport/tlsfrag/buffer.go | 113 +++++++++++-------- transport/tlsfrag/buffer_test.go | 185 ++++++++++++++++--------------- 2 files changed, 164 insertions(+), 134 deletions(-) diff --git a/transport/tlsfrag/buffer.go b/transport/tlsfrag/buffer.go index 4279c590..e7184bb5 100644 --- a/transport/tlsfrag/buffer.go +++ b/transport/tlsfrag/buffer.go @@ -32,10 +32,10 @@ var ( // clientHelloBuffer is a byte buffer used to receive and buffer a TLS Client Hello packet. type clientHelloBuffer struct { - data []byte // the buffer that hosts both header and content, cap: 5 -> 5+len(content) - len int // the number of bytes that have already been read into data + data []byte // the buffer that hosts both header and content, len: 5 -> 5+len(content) + len int // the actual bytes that have been read into data valid bool // indicate whether the content in data is a valid TLS Client Hello record - toRead int // the number of bytes to read next, e.g. 1 -> 2 -> 2 -> len(content) + toRead int // the number of bytes to read next, e.g. 5 -> len(content) } var _ io.Writer = (*clientHelloBuffer)(nil) @@ -46,9 +46,8 @@ func newClientHelloBuffer() *clientHelloBuffer { // Allocate the 5 bytes header first, and then reallocate it to contain the entire packet later return &clientHelloBuffer{ data: make([]byte, recordHeaderLen), - len: 0, valid: true, - toRead: tlsRecordWithTypeSize, + toRead: recordHeaderLen, } } @@ -62,12 +61,6 @@ func (b *clientHelloBuffer) Bytes() []byte { return b.data[:b.len] } -func (b *clientHelloBuffer) growBy(size int) { - buf := make([]byte, b.len+size) - copy(buf, b.data[:b.len]) - b.data = buf -} - // Write appends p to the buffer and returns the number of bytes actually used. // If this data completes a valid TLS Client Hello, it returns errTLSClientHelloFullyReceived. // If an invalid TLS Client Hello message is detected, it returns the error errInvalidTLSClientHello. @@ -78,46 +71,26 @@ func (b *clientHelloBuffer) Write(p []byte) (n int, err error) { return 0, errInvalidTLSClientHello } - defer func() { b.valid = !errors.Is(err, errInvalidTLSClientHello) }() - - for b.toRead > 0 { - if len(p) == 0 { - return - } - - sz := b.toRead - if len(p) < sz { - sz = len(p) - } - copy(b.data[b.len:], p[:sz]) - n += sz - b.len += sz - b.toRead -= sz - p = p[sz:] - - // check whether message is valid according to the bytes just read - switch b.len { - case tlsRecordWithTypeSize: // 1 - if typ := getRecordType(b.data); typ != recordTypeHandshake { - return n, fmt.Errorf("record type %d is not handshake: %w", typ, errInvalidTLSClientHello) - } - b.toRead = tlsRecordWithVersionHeaderSize - tlsRecordWithTypeSize // +2 + for b.len < len(b.data) && len(p) > 0 { + m := copy(b.data[b.len:], p) + n += m + b.len += m + p = p[m:] - case tlsRecordWithVersionHeaderSize: // 3 - if ver := getTLSVersion(b.data); !isValidTLSVersion(ver) { - return n, fmt.Errorf("%#04x is not a valid TLS version: %w", ver, errInvalidTLSClientHello) + if b.len == recordHeaderLen { + if err = b.validateTLSClientHello(); err != nil { + return } - b.toRead = recordHeaderLen - tlsRecordWithVersionHeaderSize // +2 - - case recordHeaderLen: // 5 - if b.toRead = int(getMsgLen(b.data)); !isValidMsgLenForHandshake(uint16(b.toRead)) { - return n, fmt.Errorf("message length %v out of range: %w", b.toRead, errInvalidTLSClientHello) - } - b.growBy(b.toRead) + buf := make([]byte, recordHeaderLen+getMsgLen(b.data)) + copy(buf, b.data) + b.data = buf } } - return n, errTLSClientHelloFullyReceived + if b.len == len(b.data) { + err = errTLSClientHelloFullyReceived + } + return } // ReadFrom reads all the data from r and appends it to this buffer until a complete Client Hello packet has been @@ -130,5 +103,51 @@ func (b *clientHelloBuffer) Write(p []byte) (n int, err error) { // // You can call ReadFrom multiple times if r doesn't provide enough data to build a complete Client Hello packet. func (b *clientHelloBuffer) ReadFrom(r io.Reader) (n int64, err error) { - return 0, errors.New("to be implemented") + if !b.valid { + return 0, errInvalidTLSClientHello + } + + for b.len < len(b.data) && err == nil { + m, e := r.Read(b.data[b.len:]) + n += int64(m) + b.len += m + err = e + + if b.len == recordHeaderLen { + if e := b.validateTLSClientHello(); e != nil { + if err == io.EOF { + err = nil + } + err = errors.Join(err, e) + return + } + buf := make([]byte, recordHeaderLen+getMsgLen(b.data)) + copy(buf, b.data) + b.data = buf + } + } + + if err == io.EOF { + err = nil + } + if b.len == len(b.data) { + err = errors.Join(err, errTLSClientHelloFullyReceived) + } + return +} + +func (b *clientHelloBuffer) validateTLSClientHello() error { + if typ := getRecordType(b.data); typ != recordTypeHandshake { + b.valid = false + return fmt.Errorf("record type %d is not handshake: %w", typ, errInvalidTLSClientHello) + } + if ver := getTLSVersion(b.data); !isValidTLSVersion(ver) { + b.valid = false + return fmt.Errorf("%#04x is not a valid TLS version: %w", ver, errInvalidTLSClientHello) + } + if len := getMsgLen(b.data); !isValidMsgLenForHandshake(len) { + b.valid = false + return fmt.Errorf("message length %v out of range: %w", len, errInvalidTLSClientHello) + } + return nil } diff --git a/transport/tlsfrag/buffer_test.go b/transport/tlsfrag/buffer_test.go index d911c515..ec6e7994 100644 --- a/transport/tlsfrag/buffer_test.go +++ b/transport/tlsfrag/buffer_test.go @@ -16,11 +16,67 @@ package tlsfrag import ( "bytes" + "net" "testing" "github.com/stretchr/testify/require" ) +// Test Write valid Client Hello to the buffer. +func TestWriteValidClientHello(t *testing.T) { + for _, tc := range validClientHelloTestCases() { + buf := newClientHelloBuffer() + + totalExpectedBytes := []byte{} + for k, pkt := range tc.pkts { + n, err := buf.Write(pkt) + if k < tc.expectLastPkt { + require.NoError(t, err, tc.msg+": pkt-%d", k) + } else { + require.ErrorIs(t, err, errTLSClientHelloFullyReceived, tc.msg+": pkt-%d", k) + } + require.Equal(t, len(pkt)-len(tc.expectRemaining[k]), n, tc.msg+": pkt-%d", k) + require.Equal(t, tc.expectRemaining[k], pkt[n:], tc.msg+": pkt-%d", k) + + totalExpectedBytes = append(totalExpectedBytes, pkt[:n]...) + require.Equal(t, len(totalExpectedBytes), buf.Len(), tc.msg+": pkt-%d", k) + require.Equal(t, totalExpectedBytes, buf.Bytes(), tc.msg+": pkt-%d", k) + } + require.Equal(t, len(tc.expectTotalPkt), buf.Len(), tc.msg) + require.Equal(t, tc.expectTotalPkt, buf.Bytes(), tc.msg) + } +} + +// Test ReadFrom Reader(s) containing valid Client Hello. +func TestReadFromValidClientHello(t *testing.T) { + for _, tc := range validClientHelloTestCases() { + buf := newClientHelloBuffer() + + totalExpectedBytes := []byte{} + for k, pkt := range tc.pkts { + r := bytes.NewBuffer(pkt) + require.Equal(t, len(pkt), r.Len(), tc.msg+": pkt-%d", k) + + n, err := buf.ReadFrom(r) + if k < tc.expectLastPkt { + require.NoError(t, err, tc.msg+": pkt-%d", k) + } else { + require.ErrorIs(t, err, errTLSClientHelloFullyReceived, tc.msg+": pkt-%d", k) + } + require.Equal(t, len(pkt)-len(tc.expectRemaining[k]), int(n), tc.msg+": pkt-%d", k) + require.Equal(t, tc.expectRemaining[k], pkt[n:], tc.msg+": pkt-%d", k) + + totalExpectedBytes = append(totalExpectedBytes, pkt[:n]...) + require.Equal(t, len(totalExpectedBytes), buf.Len(), tc.msg+": pkt-%d", k) + require.Equal(t, totalExpectedBytes, buf.Bytes(), tc.msg+": pkt-%d", k) + require.Equal(t, len(tc.expectRemaining[k]), r.Len(), tc.msg+": pkt-%d", k) + require.Equal(t, tc.expectRemaining[k], r.Bytes(), tc.msg+": pkt-%d", k) + } + require.Equal(t, len(tc.expectTotalPkt), buf.Len(), tc.msg) + require.Equal(t, tc.expectTotalPkt, buf.Bytes(), tc.msg) + } +} + // Example TLS Client Hello packet copied from https://tls13.xargs.org/#client-hello // Total Len = 253, ContentLen = 253 - 5 = 248 var exampleTLS13ClientHello = []byte{ @@ -39,93 +95,48 @@ var exampleTLS13ClientHello = []byte{ 0x21, 0xa2, 0x8e, 0x3b, 0x75, 0xe9, 0x65, 0xd0, 0xd2, 0xcd, 0x16, 0x62, 0x54, } -// Test ReadFrom a Reader containing exactly one Client Hello record. -func TestReadFromClientHelloSingleReader(t *testing.T) { - buf := newClientHelloBuffer() - r := bytes.NewBuffer(exampleTLS13ClientHello) - require.Equal(t, 253, r.Len()) - - n, err := buf.ReadFrom(r) - require.NoError(t, err) - require.Equal(t, 253, int(n)) - require.Equal(t, 0, r.Len()) - - require.True(t, buf.HasFullyReceived()) - require.Equal(t, exampleTLS13ClientHello[5:], buf.Content()) - require.Equal(t, 248, len(buf.Content())) +type validClientHelloCase struct { + msg string + pkts net.Buffers + expectTotalPkt []byte + expectLastPkt int // the index of the expected last packet + expectRemaining net.Buffers } -// Test ReadFrom multiple Readers containing exactly one Client Hello record. -func TestReadFromClientHelloMultipleReaders(t *testing.T) { - buf := newClientHelloBuffer() - r1 := bytes.NewBuffer(exampleTLS13ClientHello[:2]) - r2 := bytes.NewBuffer(exampleTLS13ClientHello[2:123]) - r3 := bytes.NewBuffer(exampleTLS13ClientHello[123:]) - - require.Equal(t, 2, r1.Len()) - n, err := buf.ReadFrom(r1) - require.NoError(t, err) - require.Equal(t, 2, int(n)) - require.Equal(t, 0, r1.Len()) - require.False(t, buf.HasFullyReceived()) - - require.Equal(t, 123-2, r2.Len()) - n, err = buf.ReadFrom(r2) - require.NoError(t, err) - require.Equal(t, 123-2, int(n)) - require.Equal(t, 0, r2.Len()) - require.False(t, buf.HasFullyReceived()) - - require.Equal(t, 253-123, r3.Len()) - n, err = buf.ReadFrom(r3) - require.NoError(t, err) - require.Equal(t, 253-123, int(n)) - require.Equal(t, 0, r3.Len()) - - require.True(t, buf.HasFullyReceived()) - require.Equal(t, exampleTLS13ClientHello[5:], buf.Content()) - require.Equal(t, 248, len(buf.Content())) -} - -// Test ReadFrom a Reader containing Client Hello and 8 more extra bytes. -func TestReadFromClientHelloExtraBytesSingleReader(t *testing.T) { - buf := newClientHelloBuffer() - r := bytes.NewBuffer(append(exampleTLS13ClientHello, 0x88, 0x87, 0x86, 0x85, 0x84, 0x83, 0x82, 0x81)) - require.Equal(t, 261, r.Len()) - - n, err := buf.ReadFrom(r) - require.NoError(t, err) - require.Equal(t, 253, int(n)) - require.Equal(t, 8, r.Len()) - - require.True(t, buf.HasFullyReceived()) - require.Equal(t, exampleTLS13ClientHello[5:], buf.Content()) - require.Equal(t, 248, len(buf.Content())) - require.Equal(t, []byte{0x88, 0x87, 0x86, 0x85, 0x84, 0x83, 0x82, 0x81}, r.Bytes()) -} - -// Test ReadFrom multiple Readers containing Client Hello and 8 more extra bytes. -func TestReadFromClientHelloExtraBytesMultipleReaders(t *testing.T) { - buf := newClientHelloBuffer() - pkt := append(exampleTLS13ClientHello, 0x88, 0x87, 0x86, 0x85, 0x84, 0x83, 0x82, 0x81) - r1 := bytes.NewBuffer(pkt[:123]) - r2 := bytes.NewBuffer(pkt[123:]) - - require.Equal(t, 123, r1.Len()) - n, err := buf.ReadFrom(r1) - require.NoError(t, err) - require.Equal(t, 123, int(n)) - require.Equal(t, 0, r1.Len()) - require.False(t, buf.HasFullyReceived()) - - require.Equal(t, 261-123, r2.Len()) - n, err = buf.ReadFrom(r2) - require.NoError(t, err) - require.Equal(t, 261-123-8, int(n)) - require.Equal(t, 8, r2.Len()) - - require.True(t, buf.HasFullyReceived()) - require.Equal(t, exampleTLS13ClientHello[5:], buf.Content()) - require.Equal(t, 248, len(buf.Content())) - require.Equal(t, []byte{0x88, 0x87, 0x86, 0x85, 0x84, 0x83, 0x82, 0x81}, r2.Bytes()) +// generating test cases that contain one valid Client Hello +func validClientHelloTestCases() []validClientHelloCase { + return []validClientHelloCase{ + { + msg: "full client hello in single buffer", + pkts: [][]byte{exampleTLS13ClientHello}, + expectTotalPkt: exampleTLS13ClientHello, + expectLastPkt: 0, + expectRemaining: [][]byte{{}}, + }, + { + msg: "full client hello with extra bytes", + pkts: [][]byte{append(exampleTLS13ClientHello, 0x88, 0x87, 0x86, 0x85, 0x84, 0x83, 0x82, 0x81)}, + expectTotalPkt: exampleTLS13ClientHello, + expectLastPkt: 0, + expectRemaining: [][]byte{{0x88, 0x87, 0x86, 0x85, 0x84, 0x83, 0x82, 0x81}}, + }, + { + msg: "client hello in three buffers", + pkts: [][]byte{exampleTLS13ClientHello[:2], exampleTLS13ClientHello[2:123], exampleTLS13ClientHello[123:]}, + expectTotalPkt: exampleTLS13ClientHello, + expectLastPkt: 2, + expectRemaining: [][]byte{{}, {}, {}}, + }, + { + msg: "client hello in three buffers with extra bytes", + pkts: [][]byte{ + exampleTLS13ClientHello[:2], + exampleTLS13ClientHello[2:123], + append(exampleTLS13ClientHello[123:], 0x88, 0x87, 0x86, 0x85, 0x84, 0x83, 0x82, 0x81), + }, + expectTotalPkt: exampleTLS13ClientHello, + expectLastPkt: 2, + expectRemaining: [][]byte{{}, {}, {0x88, 0x87, 0x86, 0x85, 0x84, 0x83, 0x82, 0x81}}, + }, + } } From 2ed83041cd5f23c083b538644245b3563c1490e6 Mon Sep 17 00:00:00 2001 From: jyyi1 Date: Thu, 16 Nov 2023 12:11:55 -0500 Subject: [PATCH 07/23] simplify stream_dialer by extracting out TLS dispatching --- transport/tlsfrag/stream_dialer.go | 125 +--------------- transport/tlsfrag/stream_dialer_test.go | 187 ------------------------ transport/tlsfrag/tls.go | 23 +-- transport/tlsfrag/writer.go | 6 +- 4 files changed, 10 insertions(+), 331 deletions(-) diff --git a/transport/tlsfrag/stream_dialer.go b/transport/tlsfrag/stream_dialer.go index 38a5991e..7c3e1563 100644 --- a/transport/tlsfrag/stream_dialer.go +++ b/transport/tlsfrag/stream_dialer.go @@ -17,10 +17,6 @@ package tlsfrag import ( "context" "errors" - "fmt" - "net" - "strconv" - "strings" "github.com/Jigsaw-Code/outline-sdk/transport" ) @@ -30,7 +26,6 @@ import ( type tlsFragDialer struct { dialer transport.StreamDialer frag FragFunc - config *DialerConfiguration } // Compilation guard against interface implementation @@ -42,19 +37,7 @@ var _ transport.StreamDialer = (*tlsFragDialer)(nil) // If the returned index is either ≤ 0 or ≥ len(record), no fragmentation will occur. // // [handshake record]: https://datatracker.ietf.org/doc/html/rfc8446#section-5.1 -type FragFunc func(record []byte) int - -// DialerConfiguration is an internal type used to configure the [transport.StreamDialer] created by -// [NewStreamDialerFunc]. You don't need to work with it directly. Instead, use the provided configuration functions -// like [WithTLSHostPortList]. -type DialerConfiguration struct { - addrs []*tlsAddrEntry -} - -// DialerConfigurer updates the settings in the internal DialerConfiguration object. You can use the configuration -// functions such as [WithTLSHostPortList] to create configurers and then pass them to NewStreamDialerFunc to create a -// [transport.StreamDialer] with your desired configuration. -type DialerConfigurer func(*DialerConfiguration) error +type FragFunc func(record []byte) (n int) // NewStreamDialerFunc creates a [transport.StreamDialer] that intercepts the initial [TLS Client Hello] // [handshake record] and splits it into two separate records before sending them. The split point is determined by the @@ -62,78 +45,26 @@ type DialerConfigurer func(*DialerConfiguration) error // using the base dialer. Following the fragmented Client Hello, all subsequent data is passed through directly without // modification. // -// NewStreamDialerFunc allows specifying additional options to customize its behavior. By default, if no options are -// specified, the fragmentation only affects TLS Client Hello messages targeting port 443. All other network traffic, -// including non-TLS or non-Client Hello messages, or those targeting other ports, are passed through without any -// modification. -// // [TLS Client Hello]: https://datatracker.ietf.org/doc/html/rfc8446#section-4.1.2 // [handshake record]: https://datatracker.ietf.org/doc/html/rfc8446#section-5.1 -func NewStreamDialerFunc(base transport.StreamDialer, frag FragFunc, options ...DialerConfigurer) (transport.StreamDialer, error) { +func NewStreamDialerFunc(base transport.StreamDialer, frag FragFunc) (transport.StreamDialer, error) { if base == nil { return nil, errors.New("base dialer must not be nil") } if frag == nil { return nil, errors.New("frag function must not be nil") } - config := &DialerConfiguration{ - addrs: []*tlsAddrEntry{{"", 443}}, - } - for _, opt := range options { - if opt != nil { - if err := opt(config); err != nil { - return nil, err - } - } - } - return &tlsFragDialer{base, frag, config}, nil -} - -// WithTLSHostPortList tells the [transport.StreamDialer] which connections to treat as TLS. Only connections matching -// entries in the tlsAddrs list will be treated as TLS traffic and fragmented accordingly. -// -// Each entry in the tlsAddrs list should be in the format "host:port", where "host" can be an IP address or a domain -// name, and "port" must be a valid port number. You can use empty string "" as the "host" to only match based on the -// port, and "0" as the "port" to match any port. -// -// The default list only includes ":443", meaning all traffic on port 443 is treated as TLS. This function overrides -// the entire list. So if you want to add entries, you need to include ":443" along with your additional entries. -// -// Matching for "host" is case-insensitive and strict. For example, "google.com:123" will only match "google.com" and -// not "www.google.com". Subdomain wildcards are not supported. -func WithTLSHostPortList(tlsAddrs []string) DialerConfigurer { - return func(c *DialerConfiguration) error { - addrs := make([]*tlsAddrEntry, 0, len(tlsAddrs)) - for _, hostport := range tlsAddrs { - addr, err := parseTLSAddrEntry(hostport) - if err != nil { - return err - } - addrs = append(addrs, addr) - } - c.addrs = addrs - return nil - } + return &tlsFragDialer{base, frag}, nil } // Dial implements [transport.StreamConn].Dial. It establishes a connection to raddr in the format "host-or-ip:port". -// -// If raddr matches an entry in the valid TLS address list (which can be configured using [WithTLSHostPortList]), the -// initial TLS Client Hello record sent through the connection will be fragmented. -// -// If raddr is not listed in the valid TLS address list, the function simply utilizes the underlying base dialer's Dial -// function to establish the connection without any fragmentation. +// The initial TLS Client Hello record sent through the connection will be fragmented. func (d *tlsFragDialer) Dial(ctx context.Context, raddr string) (conn transport.StreamConn, err error) { conn, err = d.dialer.Dial(ctx, raddr) if err != nil { return } - for _, addr := range d.config.addrs { - if addr.matches(raddr) { - return WrapConnFunc(conn, d.frag) - } - } - return + return WrapConnFunc(conn, d.frag) } // WrapConnFunc wraps the base [transport.StreamConn] and splits the first TLS Client Hello packet into two records @@ -146,49 +77,3 @@ func WrapConnFunc(base transport.StreamConn, frag FragFunc) (transport.StreamCon } return transport.WrapConn(base, base, w), nil } - -// tlsAddrEntry reprsents an entry of the TLS traffic list. See [WithTLSHostPortList]. -type tlsAddrEntry struct { - host string - port int -} - -// parseTLSAddrEntry parses hostport in format "host:port" and returns the corresponding tlsAddrEntry. -func parseTLSAddrEntry(hostport string) (*tlsAddrEntry, error) { - host, portStr, err := net.SplitHostPort(hostport) - if err != nil { - return nil, err - } - port, err := strconv.Atoi(portStr) - if err != nil { - return nil, err - } - if port < 0 || port > 65535 { - return nil, fmt.Errorf("port must be within 0-65535: %w", strconv.ErrRange) - } - return &tlsAddrEntry{host, port}, nil -} - -// matches returns whether raddr matches this entry. -func (e *tlsAddrEntry) matches(raddr string) bool { - if len(e.host) == 0 && e.port == 0 { - return true - } - host, portStr, err := net.SplitHostPort(raddr) - if err != nil { - return false - } - if len(e.host) > 0 && !strings.EqualFold(e.host, host) { - return false - } - if e.port > 0 { - port, err := strconv.Atoi(portStr) - if err != nil { - return false - } - if e.port != port { - return false - } - } - return true -} diff --git a/transport/tlsfrag/stream_dialer_test.go b/transport/tlsfrag/stream_dialer_test.go index eadccbec..8065fe06 100644 --- a/transport/tlsfrag/stream_dialer_test.go +++ b/transport/tlsfrag/stream_dialer_test.go @@ -13,190 +13,3 @@ // limitations under the License. package tlsfrag - -import ( - "context" - "strconv" - "testing" - - "github.com/Jigsaw-Code/outline-sdk/transport" - "github.com/stretchr/testify/require" -) - -// this is the local conn that can be shared across tests -var theLocalConn = &localConn{} - -// Make sure NewStreamDialer returns error on invalid WithTLSHostPortList calls. -func TestNewStreamDialerWithInvalidTLSAddr(t *testing.T) { - cases := []struct { - addr string - errType error // nil indicates general error - }{ - {"1.2.3.4", nil}, - {":::::", nil}, - {"1.2.3.4:654-321", strconv.ErrSyntax}, - {"1.2.3.4:--8080", strconv.ErrSyntax}, - {"[::]:10000000000", strconv.ErrRange}, - {"1.2.3.4:-1234", strconv.ErrRange}, - {":654321", strconv.ErrRange}, - } - for _, tc := range cases { - d, err := NewStreamDialerFunc(localConnDialer{}, func([]byte) int { return 0 }, WithTLSHostPortList([]string{tc.addr})) - require.Error(t, err, tc.addr) - if tc.errType != nil { - require.ErrorIs(t, err, tc.errType, tc.addr) - } - require.Nil(t, d) - } -} - -// Make sure no fragmentation connection is created if raddr is not in the allowed list. -func TestDialFragmentOnTLSAddrOnly(t *testing.T) { - tlsAddrs := []string{ - ":443", // default entry - ":990", // additional FTPS port - ":853", // additional DNS-over-TLS port - "pop.gmail.com:995", // Gmail pop3 - } - cases := []struct { - msg string - raddrs []string - shouldFrag bool - shouldFragWithList bool - }{ - { - msg: "*:443 should be fragmented, raddr = %s", - raddrs: []string{"example.com:443", "66.77.88.99:443", "[2001:db8::1]:443"}, - shouldFrag: true, - shouldFragWithList: true, - }, - { - msg: "*:990 should be fragmented by allowlist, raddr = %s", - raddrs: []string{"my-test.org:990", "192.168.1.10:990", "[2001:db8:3333:4444:5555:6666:7777:8888]:990"}, - shouldFrag: false, - shouldFragWithList: true, - }, - { - msg: "*:8080 should not be fragmented, raddr = %s", - raddrs: []string{"google.com:8080", "64.233.191.255:8080", "[2001:db8:3333:4444:5555:6666:7777:8888]:8080"}, - shouldFrag: false, - shouldFragWithList: false, - }, - { - msg: "DNS ports should not be fragmented, raddr = %s", - raddrs: []string{"8.8.8.8:53", "8.8.4.4:53", "2001:4860:4860::8888", "2001:4860:4860::8844"}, - shouldFrag: false, - shouldFragWithList: false, - }, - { - msg: "DNS over TLS ports should be fragmented by allowlist, raddr = %s", - raddrs: []string{"9.9.9.9:853", "8.8.4.4:853", "[2001:4860:4860::8844]:853", "[2620:fe::fe]:853"}, - shouldFrag: false, - shouldFragWithList: true, - }, - { - msg: "only gmail POP3 should be fragmented by allowlist, raddr = %s", - raddrs: []string{"pop.GMail.com:995"}, - shouldFrag: false, - shouldFragWithList: true, - }, - { - msg: "non-gmail POP3 should not be fragmented, raddr = %s", - raddrs: []string{"8.8.8.8:995", "outlook.office365.com:995", "outlook.office365.com:993", "pop.gmail.com:993"}, - shouldFrag: false, - shouldFragWithList: false, - }, - } - - base := localConnDialer{} - assertShouldFrag := func(conn transport.StreamConn, msg, addr string) { - prevWrCnt := theLocalConn.writeCount - // this Write should not be pushed to theLocalConn yet because it's a valid TLS handshake - conn.Write([]byte{22}) - - nonFragConn, ok := conn.(*localConn) - require.False(t, ok, msg, addr) - require.Nil(t, nonFragConn, msg) - require.Equal(t, prevWrCnt, theLocalConn.writeCount, msg, addr) - } - assertShouldNotFrag := func(conn transport.StreamConn, msg, addr string) { - prevWrCnt := theLocalConn.writeCount - // this Write should be pushed to theLocalConn because it's a direct Write call - conn.Write([]byte{22}) - - nonFragConn, ok := conn.(*localConn) - require.True(t, ok, msg, addr) - require.NotNil(t, nonFragConn, msg, addr) - require.Equal(t, theLocalConn, nonFragConn) - require.Equal(t, prevWrCnt+1, theLocalConn.writeCount, msg, addr) - } - - // default dialer - d1, err := NewStreamDialerFunc(base, func([]byte) int { return 0 }) - require.NoError(t, err) - require.NotNil(t, d1) - - // with additional tls addrs - d2, err := NewStreamDialerFunc(base, func([]byte) int { return 0 }, WithTLSHostPortList(tlsAddrs)) - require.NoError(t, err) - require.NotNil(t, d2) - - // with no tls addrs - d3, err := NewStreamDialerFunc(base, func([]byte) int { return 0 }, WithTLSHostPortList([]string{})) - require.NoError(t, err) - require.NotNil(t, d3) - - // all traffic - d4, err := NewStreamDialerFunc(base, func([]byte) int { return 0 }, WithTLSHostPortList([]string{":0"})) - require.NoError(t, err) - require.NotNil(t, d4) - - for _, tc := range cases { - for _, addr := range tc.raddrs { - conn, err := d1.Dial(context.Background(), addr) - require.NoError(t, err, tc.msg, addr) - require.NotNil(t, conn, tc.msg, addr) - if tc.shouldFrag { - assertShouldFrag(conn, tc.msg, addr) - } else { - assertShouldNotFrag(conn, tc.msg, addr) - } - - conn, err = d2.Dial(context.Background(), addr) - require.NoError(t, err, tc.msg, addr) - require.NotNil(t, conn, tc.msg, addr) - if tc.shouldFragWithList { - assertShouldFrag(conn, tc.msg, addr) - } else { - assertShouldNotFrag(conn, tc.msg, addr) - } - - conn, err = d3.Dial(context.Background(), addr) - require.NoError(t, err, tc.msg, addr) - require.NotNil(t, conn, tc.msg, addr) - assertShouldNotFrag(conn, tc.msg, addr) - - conn, err = d4.Dial(context.Background(), addr) - require.NoError(t, err, tc.msg, addr) - require.NotNil(t, conn, tc.msg, addr) - assertShouldFrag(conn, tc.msg, addr) - } - } -} - -// testing utilitites - -type localConnDialer struct{} -type localConn struct { - transport.StreamConn - writeCount int -} - -func (localConnDialer) Dial(ctx context.Context, raddr string) (transport.StreamConn, error) { - return theLocalConn, nil -} - -func (lc *localConn) Write(b []byte) (n int, err error) { - lc.writeCount++ - return len(b), nil -} diff --git a/transport/tlsfrag/tls.go b/transport/tlsfrag/tls.go index df915c89..3578c324 100644 --- a/transport/tlsfrag/tls.go +++ b/transport/tlsfrag/tls.go @@ -20,7 +20,7 @@ import ( // This file contains helper functions and constants for TLS Client Hello message. -type recordType uint8 +type recordType byte type tlsVersion uint16 // TLS record layout from [RFC 8446]: @@ -46,9 +46,6 @@ type tlsVersion uint16 // // [RFC 8446]: https://datatracker.ietf.org/doc/html/rfc8446#section-5.1 const ( - tlsRecordWithTypeSize = 1 // the minimum size that contains record type - tlsRecordWithVersionHeaderSize = 3 // the minimum size that contains protocol version - recordHeaderLen = 5 maxMsgLen = 1 << 14 @@ -65,21 +62,11 @@ func getRecordType(hdr []byte) recordType { return recordType(hdr[0]) } -// putRecordType puts the TLS record type to the TLS header hdr[0]. This function will panic if len(hdr) < 1. -func putRecordType(hdr []byte, typ recordType) { - hdr[0] = byte(typ) -} - // getTLSVersion gets the TLS version from the TLS header hdr[1:3]. This function will panic if len(hdr) < 3. func getTLSVersion(hdr []byte) tlsVersion { return tlsVersion(binary.BigEndian.Uint16(hdr[1:])) } -// putTLSVersion puts the TLS version to the TLS header hdr[1:3]. This function will panic if len(hdr) < 3. -func putTLSVersion(hdr []byte, ver tlsVersion) { - binary.BigEndian.PutUint16(hdr[1:], uint16(ver)) -} - // getMsgLen gets the TLS message length from the TLS header hdr[3:5]. This function will panic if len(hdr) < 5. func getMsgLen(hdr []byte) uint16 { return binary.BigEndian.Uint16(hdr[3:]) @@ -106,11 +93,3 @@ func isValidTLSVersion(ver tlsVersion) bool { func isValidMsgLenForHandshake(len uint16) bool { return 0 < len && len <= maxMsgLen } - -// This function will panic if len(hdr) < 5. -func putTLSClientHelloHeader(hdr []byte, recordLen uint16) { - _ = hdr[recordHeaderLen-1] // bounds check to guarantee safety of writes below - putRecordType(hdr, recordTypeHandshake) - putTLSVersion(hdr, versionTLS10) - putMsgLen(hdr, recordLen) -} diff --git a/transport/tlsfrag/writer.go b/transport/tlsfrag/writer.go index 229fe840..62f77f62 100644 --- a/transport/tlsfrag/writer.go +++ b/transport/tlsfrag/writer.go @@ -145,13 +145,15 @@ func (w *clientHelloFragWriter) splitBufToRecords() { } header := make([]byte, recordHeaderLen) + copy(header, w.buf.Bytes()) + w.rcds = bytes.NewBuffer(make([]byte, 0, w.buf.Len()+recordHeaderLen)) - putTLSClientHelloHeader(header, uint16(split)) + putMsgLen(header, uint16(split)) w.rcds.Write(header) w.rcds.Write(content[:split]) - putTLSClientHelloHeader(header, uint16(len(content)-split)) + putMsgLen(header, uint16(len(content)-split)) w.rcds.Write(header) w.rcds.Write(content[split:]) From 4bc8f2c0925a0fada2def1bfdd9e9f1be520a222 Mon Sep 17 00:00:00 2001 From: jyyi1 Date: Thu, 16 Nov 2023 17:27:46 -0500 Subject: [PATCH 08/23] prevent writing empty buffers after split --- transport/tlsfrag/writer.go | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/transport/tlsfrag/writer.go b/transport/tlsfrag/writer.go index 62f77f62..70d370aa 100644 --- a/transport/tlsfrag/writer.go +++ b/transport/tlsfrag/writer.go @@ -87,9 +87,17 @@ func (w *clientHelloFragWriter) Write(p []byte) (n int, err error) { } else { w.copyBufToRecords() } - // recursively flush w.rcds and write the remaining content - m, e := w.Write(p[n:]) - return n + m, e + // We did not call w.Write(p[n:]) here because p[n:] might be empty, and we don't want to + // Write an empty buffer to w.base if it's not initiated by the upstream caller. + if _, err = w.flushRecords(); err != nil { + return + } + if p = p[n:]; len(p) > 0 { + m, e := w.base.Write(p) + n += m + err = e + } + return } if n < len(p) { From 5a687a0d72019ee9eb630d097a1b4c545179b240 Mon Sep 17 00:00:00 2001 From: jyyi1 Date: Thu, 16 Nov 2023 18:39:21 -0500 Subject: [PATCH 09/23] add test cases for stream dialer --- network/dnstruncate/packet_proxy_test.go | 18 +-- transport/tlsfrag/stream_dialer_test.go | 141 +++++++++++++++++++++++ 2 files changed, 150 insertions(+), 9 deletions(-) diff --git a/network/dnstruncate/packet_proxy_test.go b/network/dnstruncate/packet_proxy_test.go index 44d53b5d..cf2112b9 100644 --- a/network/dnstruncate/packet_proxy_test.go +++ b/network/dnstruncate/packet_proxy_test.go @@ -164,15 +164,15 @@ func constructDNSQuestionsFromDomainNames(questions []string) []layers.DNSQuesti // constructDNSRequestOrResponse creates the following DNS request/response: // -// [ `id` ]: 2 bytes -// [ Standard-Query/Response + Recursive ]: 0x01/0x81 -// [ Reserved/Response-No-Err ]: 0x00 -// [ Questions-Count ]: 2 bytes (= len(questions)) -// [ Answers Count ]: 2 bytes (= 0x00 0x00 / len(questions)) -// [ Authorities Count ]: 0x00 0x00 -// [ Resources Count ]: 0x00 0x01 -// [ `questions` ]: ? bytes -// [ Additional Resources ]: ? bytes (= OPT(payload_size=4096)) +// [ `id` ]: 2 bytes +// [ Standard-Query/Response + Recursive ]: 0x01/0x81 +// [ Reserved/Response-No-Err ]: 0x00 +// [ Questions-Count ]: 2 bytes (= len(questions)) +// [ Answers Count ]: 2 bytes (= 0x00 0x00 / len(questions)) +// [ Authorities Count ]: 0x00 0x00 +// [ Resources Count ]: 0x00 0x01 +// [ `questions` ]: ? bytes +// [ Additional Resources ]: ? bytes (= OPT(payload_size=4096)) // // https://datatracker.ietf.org/doc/html/rfc1035#section-4.1.1 // diff --git a/transport/tlsfrag/stream_dialer_test.go b/transport/tlsfrag/stream_dialer_test.go index 8065fe06..887f8f5c 100644 --- a/transport/tlsfrag/stream_dialer_test.go +++ b/transport/tlsfrag/stream_dialer_test.go @@ -13,3 +13,144 @@ // limitations under the License. package tlsfrag + +import ( + "context" + "errors" + "io" + "net" + "testing" + "time" + + "github.com/Jigsaw-Code/outline-sdk/transport" + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + "github.com/stretchr/testify/require" +) + +// Make sure only the first Client Hello is splitted. +func TestStreamDialerFuncSplitsClientHello(t *testing.T) { + hello := constructTLSRecord(t, layers.TLSHandshake, 0x0301, []byte{0x01, 0x00, 0x00, 0x03, 0xaa, 0xbb, 0xcc}) + cipher := constructTLSRecord(t, layers.TLSChangeCipherSpec, 0x0303, []byte{0x01}) + req1 := constructTLSRecord(t, layers.TLSApplicationData, 0x0303, []byte{0xff, 0xee, 0xdd, 0xcc, 0xbb, 0xaa, 0x99, 0x88}) + + inner := &collectStreamDialer{} + conn := assertCanDialFragFunc(t, inner, "ipinfo.io:443", func(_ []byte) int { return 2 }) + defer conn.Close() + + assertCanWriteAll(t, conn, net.Buffers{hello, cipher, req1, hello, cipher, req1}) + + frag1 := constructTLSRecord(t, layers.TLSHandshake, 0x0301, []byte{0x01, 0x00}) + frag2 := constructTLSRecord(t, layers.TLSHandshake, 0x0301, []byte{0x00, 0x03, 0xaa, 0xbb, 0xcc}) + expected := net.Buffers{ + append(frag1, frag2...), // fragment 1 and fragment 2 will be merged in one single Write + cipher, req1, hello, cipher, req1, // unchanged + } + require.Equal(t, expected, inner.bufs) +} + +// Make sure we don't split if the first packet is not a Client Hello. +func TestStreamDialerFuncDontSplitNonClientHello(t *testing.T) { + cases := []struct { + msg string + pkt []byte + }{ + { + msg: "application data", + pkt: constructTLSRecord(t, layers.TLSApplicationData, 0x0303, []byte{0x01, 0x00, 0x00, 0x03, 0xdd, 0xee, 0xff}), + }, + { + msg: "cipher", + pkt: constructTLSRecord(t, layers.TLSChangeCipherSpec, 0x0303, []byte{0xff}), + }, + { + msg: "invalid version", + pkt: constructTLSRecord(t, layers.TLSHandshake, 0x0305, []byte{0x01, 0x00, 0x00, 0x03, 0xdd, 0xee, 0xff}), + }, + { + msg: "invalid length", + pkt: constructTLSRecord(t, layers.TLSHandshake, 0x0305, []byte{}), + }, + } + + cipher := constructTLSRecord(t, layers.TLSChangeCipherSpec, 0x0303, []byte{0x01}) + req := constructTLSRecord(t, layers.TLSApplicationData, 0x0303, []byte{0xff, 0xee, 0xdd, 0xcc, 0xbb, 0xaa, 0x99, 0x88}) + + for _, tc := range cases { + inner := &collectStreamDialer{} + conn := assertCanDialFragFunc(t, inner, "ipinfo.io:443", func(_ []byte) int { return 2 }) + defer conn.Close() + + assertCanWriteAll(t, conn, net.Buffers{tc.pkt, cipher, req}) + expected := net.Buffers{tc.pkt, cipher, req} + if len(tc.pkt) > 5 { + // header and content of the first pkt might be issued by two Writes, but they are not fragmented + expected = net.Buffers{tc.pkt[:5], tc.pkt[5:], cipher, req} + } + require.Equal(t, expected, inner.bufs, tc.msg) + } +} + +// test assertions + +func assertCanDialFragFunc(t *testing.T, inner transport.StreamDialer, raddr string, frag FragFunc) transport.StreamConn { + d, err := NewStreamDialerFunc(inner, frag) + require.NoError(t, err) + require.NotNil(t, d) + conn, err := d.Dial(context.Background(), raddr) + require.NoError(t, err) + require.NotNil(t, conn) + return conn +} + +func assertCanWriteAll(t *testing.T, w io.Writer, buf net.Buffers) { + for _, p := range buf { + n, err := w.Write(p) + require.NoError(t, err) + require.Equal(t, len(p), n) + } +} + +// private test helpers + +func constructTLSRecord(t *testing.T, typ layers.TLSType, ver layers.TLSVersion, payload []byte) []byte { + pkt := layers.TLS{ + AppData: []layers.TLSAppDataRecord{{ + TLSRecordHeader: layers.TLSRecordHeader{ + ContentType: typ, + Version: ver, + Length: uint16(len(payload)), + }, + Payload: payload, + }}, + } + + buf := gopacket.NewSerializeBuffer() + err := pkt.SerializeTo(buf, gopacket.SerializeOptions{}) + require.NoError(t, err) + return buf.Bytes() +} + +// collectStreamDialer collects all writes to this stream dialer and append it to bufs +type collectStreamDialer struct { + bufs net.Buffers +} + +func (d *collectStreamDialer) Dial(ctx context.Context, raddr string) (transport.StreamConn, error) { + return d, nil +} + +func (c *collectStreamDialer) Write(p []byte) (int, error) { + c.bufs = append(c.bufs, append([]byte{}, p...)) // copy p rather than retaining it according to the principle of Write + return len(p), nil +} + +func (c *collectStreamDialer) Read(p []byte) (int, error) { return 0, errors.New("not supported") } +func (c *collectStreamDialer) Close() error { return nil } +func (c *collectStreamDialer) CloseRead() error { return nil } +func (c *collectStreamDialer) CloseWrite() error { return nil } +func (c *collectStreamDialer) LocalAddr() net.Addr { return nil } +func (c *collectStreamDialer) RemoteAddr() net.Addr { return nil } +func (c *collectStreamDialer) SetDeadline(t time.Time) error { return errors.New("not supported") } +func (c *collectStreamDialer) SetReadDeadline(t time.Time) error { return errors.New("not supported") } +func (c *collectStreamDialer) SetWriteDeadline(t time.Time) error { return errors.New("not supported") } From ddf46f0f633d0efab70f3c1f601da30989c19fa3 Mon Sep 17 00:00:00 2001 From: "J. Yi" <93548144+jyyi1@users.noreply.github.com> Date: Fri, 17 Nov 2023 13:47:08 -0500 Subject: [PATCH 10/23] Update transport/tlsfrag/buffer.go Co-authored-by: Vinicius Fortuna --- transport/tlsfrag/buffer.go | 5 ----- 1 file changed, 5 deletions(-) diff --git a/transport/tlsfrag/buffer.go b/transport/tlsfrag/buffer.go index e7184bb5..35113592 100644 --- a/transport/tlsfrag/buffer.go +++ b/transport/tlsfrag/buffer.go @@ -51,11 +51,6 @@ func newClientHelloBuffer() *clientHelloBuffer { } } -// Len returns the length of this buffer including both the 5 bytes header and the content. -func (b *clientHelloBuffer) Len() int { - return b.len -} - // Bytes returns the full Client Hello packet including both the 5 bytes header and the content. func (b *clientHelloBuffer) Bytes() []byte { return b.data[:b.len] From 4f00ec880eef6229a3602b2086bbb1dbbc38d08f Mon Sep 17 00:00:00 2001 From: "J. Yi" <93548144+jyyi1@users.noreply.github.com> Date: Fri, 17 Nov 2023 13:49:59 -0500 Subject: [PATCH 11/23] Update transport/tlsfrag/buffer.go Co-authored-by: Vinicius Fortuna --- transport/tlsfrag/buffer.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transport/tlsfrag/buffer.go b/transport/tlsfrag/buffer.go index 35113592..3f198d1a 100644 --- a/transport/tlsfrag/buffer.go +++ b/transport/tlsfrag/buffer.go @@ -45,7 +45,7 @@ var _ io.ReaderFrom = (*clientHelloBuffer)(nil) func newClientHelloBuffer() *clientHelloBuffer { // Allocate the 5 bytes header first, and then reallocate it to contain the entire packet later return &clientHelloBuffer{ - data: make([]byte, recordHeaderLen), + data: make([]byte, 0, recordHeaderLen), valid: true, toRead: recordHeaderLen, } From edb97bbb75bac8f3786229247874f712ba7dd883 Mon Sep 17 00:00:00 2001 From: "J. Yi" <93548144+jyyi1@users.noreply.github.com> Date: Fri, 17 Nov 2023 13:50:20 -0500 Subject: [PATCH 12/23] Update transport/tlsfrag/buffer.go Co-authored-by: Vinicius Fortuna --- transport/tlsfrag/buffer.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transport/tlsfrag/buffer.go b/transport/tlsfrag/buffer.go index 3f198d1a..fc06f2f4 100644 --- a/transport/tlsfrag/buffer.go +++ b/transport/tlsfrag/buffer.go @@ -53,7 +53,7 @@ func newClientHelloBuffer() *clientHelloBuffer { // Bytes returns the full Client Hello packet including both the 5 bytes header and the content. func (b *clientHelloBuffer) Bytes() []byte { - return b.data[:b.len] + return b.data } // Write appends p to the buffer and returns the number of bytes actually used. From 87c3592047ac02e10b99c43ef492b920019f77cc Mon Sep 17 00:00:00 2001 From: jyyi1 Date: Tue, 21 Nov 2023 17:27:25 -0500 Subject: [PATCH 13/23] refactor buffer to get rid of b.len and use cap instead --- transport/tlsfrag/buffer.go | 62 +++++-------- transport/tlsfrag/buffer_test.go | 6 +- transport/tlsfrag/writer.go | 148 +++++++++++++++---------------- 3 files changed, 98 insertions(+), 118 deletions(-) diff --git a/transport/tlsfrag/buffer.go b/transport/tlsfrag/buffer.go index fc06f2f4..7f780a49 100644 --- a/transport/tlsfrag/buffer.go +++ b/transport/tlsfrag/buffer.go @@ -15,6 +15,7 @@ package tlsfrag import ( + "bytes" "errors" "fmt" "io" @@ -32,10 +33,10 @@ var ( // clientHelloBuffer is a byte buffer used to receive and buffer a TLS Client Hello packet. type clientHelloBuffer struct { - data []byte // the buffer that hosts both header and content, len: 5 -> 5+len(content) - len int // the actual bytes that have been read into data - valid bool // indicate whether the content in data is a valid TLS Client Hello record - toRead int // the number of bytes to read next, e.g. 5 -> len(content) + data []byte // The buffer that hosts both header and content, cap: 5 -> 5+len(content)+padding + padding int // The unused additional padding allocated at the end of data, 0 -> 5 + valid bool // Indicates whether the content in data is a valid TLS Client Hello record + bufrd *bytes.Reader // A reader used to read from the slice passed to Write } var _ io.Writer = (*clientHelloBuffer)(nil) @@ -45,9 +46,10 @@ var _ io.ReaderFrom = (*clientHelloBuffer)(nil) func newClientHelloBuffer() *clientHelloBuffer { // Allocate the 5 bytes header first, and then reallocate it to contain the entire packet later return &clientHelloBuffer{ - data: make([]byte, 0, recordHeaderLen), - valid: true, - toRead: recordHeaderLen, + data: make([]byte, 0, recordHeaderLen), + padding: 0, + valid: true, + bufrd: bytes.NewReader(nil), // It will be Reset in Write } } @@ -61,31 +63,13 @@ func (b *clientHelloBuffer) Bytes() []byte { // If an invalid TLS Client Hello message is detected, it returns the error errInvalidTLSClientHello. // If all bytes in p have been used and the buffer still requires more data to build a complete TLS Client Hello // message, it returns (len(p), nil). -func (b *clientHelloBuffer) Write(p []byte) (n int, err error) { - if !b.valid { - return 0, errInvalidTLSClientHello +func (b *clientHelloBuffer) Write(p []byte) (int, error) { + b.bufrd.Reset(p) + n, err := b.ReadFrom(b.bufrd) + if err == nil && int(n) != len(p) { + err = io.ErrShortWrite } - - for b.len < len(b.data) && len(p) > 0 { - m := copy(b.data[b.len:], p) - n += m - b.len += m - p = p[m:] - - if b.len == recordHeaderLen { - if err = b.validateTLSClientHello(); err != nil { - return - } - buf := make([]byte, recordHeaderLen+getMsgLen(b.data)) - copy(buf, b.data) - b.data = buf - } - } - - if b.len == len(b.data) { - err = errTLSClientHelloFullyReceived - } - return + return int(n), err } // ReadFrom reads all the data from r and appends it to this buffer until a complete Client Hello packet has been @@ -102,13 +86,13 @@ func (b *clientHelloBuffer) ReadFrom(r io.Reader) (n int64, err error) { return 0, errInvalidTLSClientHello } - for b.len < len(b.data) && err == nil { - m, e := r.Read(b.data[b.len:]) + for len(b.data) < cap(b.data)-b.padding && err == nil { + m, e := r.Read(b.data[len(b.data) : cap(b.data)-b.padding]) + b.data = b.data[:len(b.data)+m] n += int64(m) - b.len += m err = e - if b.len == recordHeaderLen { + if len(b.data) == recordHeaderLen { if e := b.validateTLSClientHello(); e != nil { if err == io.EOF { err = nil @@ -116,16 +100,16 @@ func (b *clientHelloBuffer) ReadFrom(r io.Reader) (n int64, err error) { err = errors.Join(err, e) return } - buf := make([]byte, recordHeaderLen+getMsgLen(b.data)) - copy(buf, b.data) - b.data = buf + buf := make([]byte, 0, recordHeaderLen*2+getMsgLen(b.data)) + b.data = append(buf, b.data...) + b.padding = recordHeaderLen } } if err == io.EOF { err = nil } - if b.len == len(b.data) { + if len(b.data) == cap(b.data)-b.padding { err = errors.Join(err, errTLSClientHelloFullyReceived) } return diff --git a/transport/tlsfrag/buffer_test.go b/transport/tlsfrag/buffer_test.go index ec6e7994..eedeccf0 100644 --- a/transport/tlsfrag/buffer_test.go +++ b/transport/tlsfrag/buffer_test.go @@ -39,11 +39,10 @@ func TestWriteValidClientHello(t *testing.T) { require.Equal(t, tc.expectRemaining[k], pkt[n:], tc.msg+": pkt-%d", k) totalExpectedBytes = append(totalExpectedBytes, pkt[:n]...) - require.Equal(t, len(totalExpectedBytes), buf.Len(), tc.msg+": pkt-%d", k) require.Equal(t, totalExpectedBytes, buf.Bytes(), tc.msg+": pkt-%d", k) } - require.Equal(t, len(tc.expectTotalPkt), buf.Len(), tc.msg) require.Equal(t, tc.expectTotalPkt, buf.Bytes(), tc.msg) + require.Equal(t, len(tc.expectTotalPkt)+5, cap(buf.Bytes()), tc.msg) } } @@ -67,13 +66,12 @@ func TestReadFromValidClientHello(t *testing.T) { require.Equal(t, tc.expectRemaining[k], pkt[n:], tc.msg+": pkt-%d", k) totalExpectedBytes = append(totalExpectedBytes, pkt[:n]...) - require.Equal(t, len(totalExpectedBytes), buf.Len(), tc.msg+": pkt-%d", k) require.Equal(t, totalExpectedBytes, buf.Bytes(), tc.msg+": pkt-%d", k) require.Equal(t, len(tc.expectRemaining[k]), r.Len(), tc.msg+": pkt-%d", k) require.Equal(t, tc.expectRemaining[k], r.Bytes(), tc.msg+": pkt-%d", k) } - require.Equal(t, len(tc.expectTotalPkt), buf.Len(), tc.msg) require.Equal(t, tc.expectTotalPkt, buf.Bytes(), tc.msg) + require.Equal(t, len(tc.expectTotalPkt)+5, cap(buf.Bytes()), tc.msg) } } diff --git a/transport/tlsfrag/writer.go b/transport/tlsfrag/writer.go index 70d370aa..4e78d16d 100644 --- a/transport/tlsfrag/writer.go +++ b/transport/tlsfrag/writer.go @@ -25,11 +25,11 @@ import ( // are not modified and are directly transmitted through the base [io.Writer]. type clientHelloFragWriter struct { base io.Writer - done bool // indicates all splitted rcds have been already written to base + done bool // Indicates all splitted rcds have been already written to base frag FragFunc - buf *clientHelloBuffer // the buffer containing and parsing a TLS Client Hello record - rcds *bytes.Buffer // the buffer containing splitted records what will be written to base + helloBuf *clientHelloBuffer // The buffer containing and parsing a TLS Client Hello record + record *bytes.Buffer // The buffer containing splitted records what will be written to base } // clientHelloFragReaderFrom serves as an optimized version of clientHelloFragWriter when the base [io.Writer] also @@ -58,9 +58,9 @@ func newClientHelloFragWriter(base io.Writer, frag FragFunc) (io.Writer, error) return nil, errors.New("frag callback function must not be nil") } fw := &clientHelloFragWriter{ - base: base, - frag: frag, - buf: newClientHelloBuffer(), + base: base, + frag: frag, + helloBuf: newClientHelloBuffer(), } if rf, ok := base.(io.ReaderFrom); ok { return &clientHelloFragReaderFrom{fw, rf}, nil @@ -71,37 +71,30 @@ func newClientHelloFragWriter(base io.Writer, frag FragFunc) (io.Writer, error) // Write implements io.Writer.Write. It attempts to split the data received in the first one or more Write call(s) // into two TLS records if the data corresponds to a TLS Client Hello record. func (w *clientHelloFragWriter) Write(p []byte) (n int, err error) { - if w.done { - return w.base.Write(p) - } - if w.rcds != nil { - if _, err = w.flushRecords(); err != nil { - return - } - return w.base.Write(p) - } - - if n, err = w.buf.Write(p); err != nil { - if errors.Is(err, errTLSClientHelloFullyReceived) { - w.splitBufToRecords() - } else { - w.copyBufToRecords() + if !w.done { + // not yet splitted, append to the buffer + if w.record == nil { + if n, err = w.helloBuf.Write(p); err == nil { + // all written, but Client Hello is not fully received yet + return + } + p = p[n:] + if errors.Is(err, errTLSClientHelloFullyReceived) { + w.splitHelloBufToRecord() + } else { + w.copyHelloBufToRecord() + } } - // We did not call w.Write(p[n:]) here because p[n:] might be empty, and we don't want to - // Write an empty buffer to w.base if it's not initiated by the upstream caller. - if _, err = w.flushRecords(); err != nil { + // already splitted (but previous Writes might fail), try to flush all remaining w.record to w.base + if _, err = w.flushRecord(); err != nil { return } - if p = p[n:]; len(p) > 0 { - m, e := w.base.Write(p) - n += m - err = e - } - return } - if n < len(p) { - return n, io.ErrShortWrite + if len(p) > 0 { + m, e := w.base.Write(p) + n += m + err = e } return } @@ -114,65 +107,70 @@ func (w *clientHelloFragWriter) Write(p []byte) (n int, err error) { // // It returns the number of bytes read. Any error except EOF encountered during the read is also returned. func (w *clientHelloFragReaderFrom) ReadFrom(r io.Reader) (n int64, err error) { - if w.done { - return w.baseRF.ReadFrom(r) - } - if w.rcds != nil { - if _, err = w.flushRecords(); err != nil { + if !w.done { + // not yet splitted, append to the buffer + if w.record == nil { + if n, err = w.helloBuf.ReadFrom(r); err == nil { + // EOF, but Client Hello is not fully received yet + return + } + if errors.Is(err, errTLSClientHelloFullyReceived) { + w.splitHelloBufToRecord() + } else { + w.copyHelloBufToRecord() + } + } + // already splitted (but previous Writes might fail), try to flush all remaining w.record to w.base + if _, err = w.flushRecord(); err != nil { return } - return w.baseRF.ReadFrom(r) } - if n, err = w.buf.ReadFrom(r); err != nil { - if errors.Is(err, errTLSClientHelloFullyReceived) { - w.splitBufToRecords() - } else { - w.copyBufToRecords() - } - // recursively flush w.rcds and read the remaining content from r - m, e := w.ReadFrom(r) - return n + m, e - } + m, e := w.baseRF.ReadFrom(r) + n += m + err = e return } -// copyBuf copies w.buf into w.rcds. -func (w *clientHelloFragWriter) copyBufToRecords() { - w.rcds = bytes.NewBuffer(w.buf.Bytes()) - w.buf = nil // allows the GC to recycle the memory +// copyHelloBufToRecord copies w.helloBuf into w.record without allocations. +func (w *clientHelloFragWriter) copyHelloBufToRecord() { + w.record = bytes.NewBuffer(w.helloBuf.Bytes()) + w.helloBuf = nil // allows the GC to recycle the memory } -// splitBuf splits w.buf into two records and put them into w.rcds. -func (w *clientHelloFragWriter) splitBufToRecords() { - content := w.buf.Bytes()[recordHeaderLen:] +// splitHelloBufToRecord splits w.helloBuf into two records and put them into w.record without allocations. +func (w *clientHelloFragWriter) splitHelloBufToRecord() { + received := w.helloBuf.Bytes() + content := received[recordHeaderLen:] split := w.frag(content) if split <= 0 || split >= len(content) { - w.copyBufToRecords() + w.copyHelloBufToRecord() return } - header := make([]byte, recordHeaderLen) - copy(header, w.buf.Bytes()) - - w.rcds = bytes.NewBuffer(make([]byte, 0, w.buf.Len()+recordHeaderLen)) - - putMsgLen(header, uint16(split)) - w.rcds.Write(header) - w.rcds.Write(content[:split]) - - putMsgLen(header, uint16(len(content)-split)) - w.rcds.Write(header) - w.rcds.Write(content[split:]) - - w.buf = nil // allows the GC to recycle the memory + // received: | <== header (5) ==> | <== split ==> | <== len(content)-split ==> | ... cap with padding (5) ... | + // \ \ + // +-----------------+ +-----------------+ + // \ \ + // splitted: | <== header (5) ==> | <== split ==> | <== header2 (5) ==> | <== len(content)-split ==> | + splitted := received[:len(received)+recordHeaderLen] + hdr1 := splitted[:recordHeaderLen] + hdr2 := splitted[recordHeaderLen+split : recordHeaderLen*2+split] + recvContent2 := splitted[recordHeaderLen+split : len(received)] + content2 := splitted[recordHeaderLen*2+split:] + copy(content2, recvContent2) + copy(hdr2, hdr1) + putMsgLen(hdr1, uint16(split)) + putMsgLen(hdr2, uint16(len(content)-split)) + w.record = bytes.NewBuffer(splitted) + w.helloBuf = nil // allows the GC to recycle the memory } -// flushRecords writes all bytes from w.rcds to base. -func (w *clientHelloFragWriter) flushRecords() (int, error) { - n, err := io.Copy(w.base, w.rcds) - if w.rcds.Len() == 0 { - w.rcds = nil // allows the GC to recycle the memory +// flushRecord writes all bytes from w.record to base. +func (w *clientHelloFragWriter) flushRecord() (int, error) { + n, err := io.Copy(w.base, w.record) + if w.record.Len() == 0 { + w.record = nil // allows the GC to recycle the memory w.done = true } return int(n), err From cb8e9f3bbbe2ecb3276e6ca1bdb8d8adb82d2aa5 Mon Sep 17 00:00:00 2001 From: jyyi1 Date: Tue, 21 Nov 2023 19:11:06 -0500 Subject: [PATCH 14/23] refactored with tlsRecordHeader type --- transport/tlsfrag/buffer.go | 22 ++++++++------- transport/tlsfrag/stream_dialer.go | 5 ++-- transport/tlsfrag/tls.go | 45 +++++++++++++++++------------- transport/tlsfrag/writer.go | 8 +++--- 4 files changed, 45 insertions(+), 35 deletions(-) diff --git a/transport/tlsfrag/buffer.go b/transport/tlsfrag/buffer.go index 7f780a49..fec537b9 100644 --- a/transport/tlsfrag/buffer.go +++ b/transport/tlsfrag/buffer.go @@ -93,14 +93,15 @@ func (b *clientHelloBuffer) ReadFrom(r io.Reader) (n int64, err error) { err = e if len(b.data) == recordHeaderLen { - if e := b.validateTLSClientHello(); e != nil { + hdr, e := b.validateTLSClientHello() + if e != nil { if err == io.EOF { err = nil } err = errors.Join(err, e) return } - buf := make([]byte, 0, recordHeaderLen*2+getMsgLen(b.data)) + buf := make([]byte, 0, recordHeaderLen*2+hdr.PayloadLen()) b.data = append(buf, b.data...) b.padding = recordHeaderLen } @@ -115,18 +116,19 @@ func (b *clientHelloBuffer) ReadFrom(r io.Reader) (n int64, err error) { return } -func (b *clientHelloBuffer) validateTLSClientHello() error { - if typ := getRecordType(b.data); typ != recordTypeHandshake { +func (b *clientHelloBuffer) validateTLSClientHello() (tlsRecordHeader, error) { + hdr := tlsRecordHeaderFromRawBytes(b.data) + if typ := hdr.Type(); typ != recordTypeHandshake { b.valid = false - return fmt.Errorf("record type %d is not handshake: %w", typ, errInvalidTLSClientHello) + return hdr, fmt.Errorf("record type %d is not handshake: %w", typ, errInvalidTLSClientHello) } - if ver := getTLSVersion(b.data); !isValidTLSVersion(ver) { + if ver := hdr.LegacyVersion(); !isValidTLSVersion(ver) { b.valid = false - return fmt.Errorf("%#04x is not a valid TLS version: %w", ver, errInvalidTLSClientHello) + return hdr, fmt.Errorf("%#04x is not a valid TLS version: %w", ver, errInvalidTLSClientHello) } - if len := getMsgLen(b.data); !isValidMsgLenForHandshake(len) { + if len := hdr.PayloadLen(); !isValidPayloadLenForHandshake(len) { b.valid = false - return fmt.Errorf("message length %v out of range: %w", len, errInvalidTLSClientHello) + return hdr, fmt.Errorf("message length %v out of range: %w", len, errInvalidTLSClientHello) } - return nil + return hdr, nil } diff --git a/transport/tlsfrag/stream_dialer.go b/transport/tlsfrag/stream_dialer.go index 7c3e1563..28ae761f 100644 --- a/transport/tlsfrag/stream_dialer.go +++ b/transport/tlsfrag/stream_dialer.go @@ -68,8 +68,9 @@ func (d *tlsFragDialer) Dial(ctx context.Context, raddr string) (conn transport. } // WrapConnFunc wraps the base [transport.StreamConn] and splits the first TLS Client Hello packet into two records -// according to the frag function. Subsequent data is forwarded without modification. If the first packet isn't a valid -// Client Hello, WrapConnFunc simply forwards all data through transparently. +// according to the frag function. Subsequent data is forwarded without modification. The Write to the base +// [transport.StreamConn] will be buffered until we have the full initial Client Hello record. If the first packet +// isn't a valid Client Hello, WrapConnFunc simply forwards all data through transparently. func WrapConnFunc(base transport.StreamConn, frag FragFunc) (transport.StreamConn, error) { w, err := newClientHelloFragWriter(base, frag) if err != nil { diff --git a/transport/tlsfrag/tls.go b/transport/tlsfrag/tls.go index 3578c324..3d74a1c9 100644 --- a/transport/tlsfrag/tls.go +++ b/transport/tlsfrag/tls.go @@ -22,6 +22,7 @@ import ( type recordType byte type tlsVersion uint16 +type tlsRecordHeader []byte // TLS record layout from [RFC 8446]: // @@ -40,14 +41,14 @@ type tlsVersion uint16 // +-------------+ Message Length + 5 // // RecordType := invalid(0) | handshake(22) | application_data(23) | ... -// Protocol Version (deprecated) := 0x0301 ("TLS 1.0") | 0x0303 ("TLS 1.2" & "TLS 1.3") | 0x0302 ("TLS 1.1") +// LegacyRecordVersion := 0x0301 ("TLS 1.0") | 0x0302 ("TLS 1.1") | 0x0303 ("TLS 1.2") // 0 < Message Length (of handshake) ≤ 2^14 // 0 ≤ Message Length (of application_data) ≤ 2^14 // // [RFC 8446]: https://datatracker.ietf.org/doc/html/rfc8446#section-5.1 const ( - recordHeaderLen = 5 - maxMsgLen = 1 << 14 + recordHeaderLen = 5 + maxRecordPayloadLen = 1 << 14 recordTypeHandshake recordType = 22 @@ -57,27 +58,33 @@ const ( versionTLS13 tlsVersion = 0x0304 ) -// getRecordType gets the TLS record type from the TLS header hdr[0]. This function will panic if len(hdr) < 1. -func getRecordType(hdr []byte) recordType { - return recordType(hdr[0]) +// tlsRecordHeaderFromRawBytes creates a new tlsRecordHeader from raw. This function will panic if len(raw) < 5. +func tlsRecordHeaderFromRawBytes(raw []byte) tlsRecordHeader { + _ = raw[recordHeaderLen-1] // early panic on invalid data + return tlsRecordHeader(raw[:recordHeaderLen]) } -// getTLSVersion gets the TLS version from the TLS header hdr[1:3]. This function will panic if len(hdr) < 3. -func getTLSVersion(hdr []byte) tlsVersion { - return tlsVersion(binary.BigEndian.Uint16(hdr[1:])) +// Type gets the TLS record type from the TLS header h[0]. This function will panic if len(h) < 1. +func (h tlsRecordHeader) Type() recordType { + return recordType(h[0]) } -// getMsgLen gets the TLS message length from the TLS header hdr[3:5]. This function will panic if len(hdr) < 5. -func getMsgLen(hdr []byte) uint16 { - return binary.BigEndian.Uint16(hdr[3:]) +// LegacyVersion gets the TLS version from the TLS header h[1:3]. This function will panic if len(h) < 3. +func (h tlsRecordHeader) LegacyVersion() tlsVersion { + return tlsVersion(binary.BigEndian.Uint16(h[1:3])) } -// putMsgLen puts the TLS message length to the TLS header hdr[3:5]. This function will panic if len(hdr) < 5. -func putMsgLen(hdr []byte, len uint16) { - binary.BigEndian.PutUint16(hdr[3:], len) +// PayloadLen gets the TLS record payload length from the TLS header h[3:5]. This function will panic if len(h) < 5. +func (h tlsRecordHeader) PayloadLen() uint16 { + return binary.BigEndian.Uint16(h[3:5]) } -// isValidTLSProtocolVersion determines whether ver is a valid TLS version according to RFC: +// SetPayloadLen puts the TLS record payload len to the TLS header h[3:5]. This function will panic if len(h) < 5. +func (h tlsRecordHeader) SetPayloadLen(len uint16) { + binary.BigEndian.PutUint16(h[3:5], len) +} + +// isValidTLSVersion determines whether ver is a valid TLS version according to RFC: // // """ // legacy_record_version: @@ -89,7 +96,7 @@ func isValidTLSVersion(ver tlsVersion) bool { return ver == versionTLS10 || ver == versionTLS11 || ver == versionTLS12 || ver == versionTLS13 } -// isValidRecordLenForHandshake checks whether 0 < len ≤ 2^14. -func isValidMsgLenForHandshake(len uint16) bool { - return 0 < len && len <= maxMsgLen +// isValidPayloadLenForHandshake checks whether 0 < len ≤ 2^14. +func isValidPayloadLenForHandshake(len uint16) bool { + return 0 < len && len <= maxRecordPayloadLen } diff --git a/transport/tlsfrag/writer.go b/transport/tlsfrag/writer.go index 4e78d16d..fbe9a334 100644 --- a/transport/tlsfrag/writer.go +++ b/transport/tlsfrag/writer.go @@ -154,14 +154,14 @@ func (w *clientHelloFragWriter) splitHelloBufToRecord() { // \ \ // splitted: | <== header (5) ==> | <== split ==> | <== header2 (5) ==> | <== len(content)-split ==> | splitted := received[:len(received)+recordHeaderLen] - hdr1 := splitted[:recordHeaderLen] - hdr2 := splitted[recordHeaderLen+split : recordHeaderLen*2+split] + hdr1 := tlsRecordHeaderFromRawBytes(splitted[:recordHeaderLen]) + hdr2 := tlsRecordHeaderFromRawBytes(splitted[recordHeaderLen+split : recordHeaderLen*2+split]) recvContent2 := splitted[recordHeaderLen+split : len(received)] content2 := splitted[recordHeaderLen*2+split:] copy(content2, recvContent2) copy(hdr2, hdr1) - putMsgLen(hdr1, uint16(split)) - putMsgLen(hdr2, uint16(len(content)-split)) + hdr1.SetPayloadLen(uint16(split)) + hdr2.SetPayloadLen(uint16(len(content) - split)) w.record = bytes.NewBuffer(splitted) w.helloBuf = nil // allows the GC to recycle the memory } From 8c0063dc1f3d6b706405c75d78ae269a1d013731 Mon Sep 17 00:00:00 2001 From: jyyi1 Date: Tue, 21 Nov 2023 19:23:06 -0500 Subject: [PATCH 15/23] try to reolve "Incorrect conversion between integer types" --- transport/socks5/socks5.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transport/socks5/socks5.go b/transport/socks5/socks5.go index 3e9f3ad2..a0ca5e44 100644 --- a/transport/socks5/socks5.go +++ b/transport/socks5/socks5.go @@ -79,7 +79,7 @@ func appendSOCKS5Address(b []byte, address string) ([]byte, error) { if err != nil { return nil, err } - portNum, err := strconv.Atoi(portStr) + portNum, err := strconv.ParseUint(portStr, 10, 16) if err != nil { return nil, err } From 25983694f049a2f68173a7da20a2c265a0c90564 Mon Sep 17 00:00:00 2001 From: jyyi1 Date: Tue, 21 Nov 2023 19:28:04 -0500 Subject: [PATCH 16/23] resolving werid error: https://github.com/Jigsaw-Code/outline-sdk/pull/133/checks?check_run_id=18912663629 --- transport/socks5/socks5.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/transport/socks5/socks5.go b/transport/socks5/socks5.go index a0ca5e44..2623cd1b 100644 --- a/transport/socks5/socks5.go +++ b/transport/socks5/socks5.go @@ -15,6 +15,7 @@ package socks5 import ( + "encoding/binary" "errors" "fmt" "net" @@ -109,6 +110,6 @@ func appendSOCKS5Address(b []byte, address string) ([]byte, error) { b = append(b, byte(len(host))) b = append(b, host...) } - b = append(b, byte(portNum>>8), byte(portNum)) + b = binary.BigEndian.AppendUint16(b, uint16(portNum)) return b, nil } From 5e6aa8773fe98526e8ad9a024242418fd5b6a4ff Mon Sep 17 00:00:00 2001 From: "J. Yi" <93548144+jyyi1@users.noreply.github.com> Date: Wed, 22 Nov 2023 15:11:42 -0500 Subject: [PATCH 17/23] Update transport/tlsfrag/writer.go Co-authored-by: Vinicius Fortuna --- transport/tlsfrag/writer.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transport/tlsfrag/writer.go b/transport/tlsfrag/writer.go index fbe9a334..c77cde8d 100644 --- a/transport/tlsfrag/writer.go +++ b/transport/tlsfrag/writer.go @@ -140,7 +140,7 @@ func (w *clientHelloFragWriter) copyHelloBufToRecord() { // splitHelloBufToRecord splits w.helloBuf into two records and put them into w.record without allocations. func (w *clientHelloFragWriter) splitHelloBufToRecord() { - received := w.helloBuf.Bytes() + originalRecord := w.helloBuf.Bytes() content := received[recordHeaderLen:] split := w.frag(content) if split <= 0 || split >= len(content) { From fb75111808ea7eb6e139b0d0e44479488fa73f37 Mon Sep 17 00:00:00 2001 From: "J. Yi" <93548144+jyyi1@users.noreply.github.com> Date: Wed, 22 Nov 2023 15:12:04 -0500 Subject: [PATCH 18/23] Update transport/tlsfrag/writer.go Co-authored-by: Vinicius Fortuna --- transport/tlsfrag/writer.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transport/tlsfrag/writer.go b/transport/tlsfrag/writer.go index c77cde8d..a217eea7 100644 --- a/transport/tlsfrag/writer.go +++ b/transport/tlsfrag/writer.go @@ -142,7 +142,7 @@ func (w *clientHelloFragWriter) copyHelloBufToRecord() { func (w *clientHelloFragWriter) splitHelloBufToRecord() { originalRecord := w.helloBuf.Bytes() content := received[recordHeaderLen:] - split := w.frag(content) + headLen := w.frag(content) if split <= 0 || split >= len(content) { w.copyHelloBufToRecord() return From c37113971483c77278ae4f999e441315e329edbb Mon Sep 17 00:00:00 2001 From: "J. Yi" <93548144+jyyi1@users.noreply.github.com> Date: Wed, 22 Nov 2023 15:12:21 -0500 Subject: [PATCH 19/23] Update transport/tlsfrag/writer.go Co-authored-by: Vinicius Fortuna --- transport/tlsfrag/writer.go | 1 + 1 file changed, 1 insertion(+) diff --git a/transport/tlsfrag/writer.go b/transport/tlsfrag/writer.go index a217eea7..51922fbb 100644 --- a/transport/tlsfrag/writer.go +++ b/transport/tlsfrag/writer.go @@ -156,6 +156,7 @@ func (w *clientHelloFragWriter) splitHelloBufToRecord() { splitted := received[:len(received)+recordHeaderLen] hdr1 := tlsRecordHeaderFromRawBytes(splitted[:recordHeaderLen]) hdr2 := tlsRecordHeaderFromRawBytes(splitted[recordHeaderLen+split : recordHeaderLen*2+split]) + // Shift tail fragment to make space for record header. recvContent2 := splitted[recordHeaderLen+split : len(received)] content2 := splitted[recordHeaderLen*2+split:] copy(content2, recvContent2) From 03142697be96f5e59e5589e2cb9706c9b46c1cb5 Mon Sep 17 00:00:00 2001 From: "J. Yi" <93548144+jyyi1@users.noreply.github.com> Date: Wed, 22 Nov 2023 15:12:40 -0500 Subject: [PATCH 20/23] Update transport/tlsfrag/writer.go Co-authored-by: Vinicius Fortuna --- transport/tlsfrag/writer.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transport/tlsfrag/writer.go b/transport/tlsfrag/writer.go index 51922fbb..cfc807f7 100644 --- a/transport/tlsfrag/writer.go +++ b/transport/tlsfrag/writer.go @@ -160,8 +160,8 @@ func (w *clientHelloFragWriter) splitHelloBufToRecord() { recvContent2 := splitted[recordHeaderLen+split : len(received)] content2 := splitted[recordHeaderLen*2+split:] copy(content2, recvContent2) + // Insert header for second fragment. copy(hdr2, hdr1) - hdr1.SetPayloadLen(uint16(split)) hdr2.SetPayloadLen(uint16(len(content) - split)) w.record = bytes.NewBuffer(splitted) w.helloBuf = nil // allows the GC to recycle the memory From 19240a9c663bdcd08353081aceff72d12b90dd13 Mon Sep 17 00:00:00 2001 From: jyyi1 Date: Wed, 22 Nov 2023 18:14:24 -0500 Subject: [PATCH 21/23] resolve comment (round 2) --- transport/tlsfrag/buffer.go | 85 +++++++++++++++++++++---------------- transport/tlsfrag/writer.go | 48 ++++++++++++--------- 2 files changed, 78 insertions(+), 55 deletions(-) diff --git a/transport/tlsfrag/buffer.go b/transport/tlsfrag/buffer.go index fec537b9..460591ad 100644 --- a/transport/tlsfrag/buffer.go +++ b/transport/tlsfrag/buffer.go @@ -27,16 +27,17 @@ var ( errTLSClientHelloFullyReceived = errors.New("already received a complete TLS Client Hello packet") // errInvalidTLSClientHello is the error used when the data received is not a valid TLS Client Hello. - // Please use [errors.Is] to compare the returned err object with this instance. errInvalidTLSClientHello = errors.New("not a valid TLS Client Hello packet") ) // clientHelloBuffer is a byte buffer used to receive and buffer a TLS Client Hello packet. type clientHelloBuffer struct { - data []byte // The buffer that hosts both header and content, cap: 5 -> 5+len(content)+padding - padding int // The unused additional padding allocated at the end of data, 0 -> 5 - valid bool // Indicates whether the content in data is a valid TLS Client Hello record - bufrd *bytes.Reader // A reader used to read from the slice passed to Write + // The buffer that hosts both header and content, cap: 5 -> 5+len(content)+padding + data []byte + // Indicates whether the content in data is a valid TLS Client Hello record + validationErr error + // A reader used to read from the slice passed to Write + bufrd *bytes.Reader } var _ io.Writer = (*clientHelloBuffer)(nil) @@ -46,10 +47,9 @@ var _ io.ReaderFrom = (*clientHelloBuffer)(nil) func newClientHelloBuffer() *clientHelloBuffer { // Allocate the 5 bytes header first, and then reallocate it to contain the entire packet later return &clientHelloBuffer{ - data: make([]byte, 0, recordHeaderLen), - padding: 0, - valid: true, - bufrd: bytes.NewReader(nil), // It will be Reset in Write + data: make([]byte, 0, recordHeaderLen), + validationErr: nil, + bufrd: bytes.NewReader(nil), // It will be Reset in Write } } @@ -82,53 +82,66 @@ func (b *clientHelloBuffer) Write(p []byte) (int, error) { // // You can call ReadFrom multiple times if r doesn't provide enough data to build a complete Client Hello packet. func (b *clientHelloBuffer) ReadFrom(r io.Reader) (n int64, err error) { - if !b.valid { - return 0, errInvalidTLSClientHello - } - - for len(b.data) < cap(b.data)-b.padding && err == nil { - m, e := r.Read(b.data[len(b.data) : cap(b.data)-b.padding]) + // Waiting to finish the header of 5 bytes + for len(b.data) < recordHeaderLen { + m, e := r.Read(b.data[len(b.data):recordHeaderLen]) b.data = b.data[:len(b.data)+m] n += int64(m) - err = e if len(b.data) == recordHeaderLen { - hdr, e := b.validateTLSClientHello() - if e != nil { - if err == io.EOF { - err = nil - } - err = errors.Join(err, e) + hdr := tlsRecordHeader(b.data) + if err = validateTLSClientHello(hdr); err != nil { + b.validationErr = err return } buf := make([]byte, 0, recordHeaderLen*2+hdr.PayloadLen()) b.data = append(buf, b.data...) - b.padding = recordHeaderLen + } + if err = e; err != nil { + if err == io.EOF { + err = nil + } + return } } - if err == io.EOF { - err = nil + // If the buffer is already invalid, or has already fully received the record + if b.validationErr != nil { + err = b.validationErr + return } - if len(b.data) == cap(b.data)-b.padding { - err = errors.Join(err, errTLSClientHelloFullyReceived) + + // Waiting to finish the payload of cap(b.data)-5 bytes + for len(b.data) < cap(b.data)-recordHeaderLen { + m, e := r.Read(b.data[len(b.data) : cap(b.data)-recordHeaderLen]) + b.data = b.data[:len(b.data)+m] + n += int64(m) + + if len(b.data) == cap(b.data)-recordHeaderLen { + err = errTLSClientHelloFullyReceived + return + } + if err = e; err != nil { + if err == io.EOF { + err = nil + } + return + } } + + err = errTLSClientHelloFullyReceived return } -func (b *clientHelloBuffer) validateTLSClientHello() (tlsRecordHeader, error) { - hdr := tlsRecordHeaderFromRawBytes(b.data) +func validateTLSClientHello(hdr tlsRecordHeader) error { if typ := hdr.Type(); typ != recordTypeHandshake { - b.valid = false - return hdr, fmt.Errorf("record type %d is not handshake: %w", typ, errInvalidTLSClientHello) + return fmt.Errorf("record type %d is not handshake: %w", typ, errInvalidTLSClientHello) } if ver := hdr.LegacyVersion(); !isValidTLSVersion(ver) { - b.valid = false - return hdr, fmt.Errorf("%#04x is not a valid TLS version: %w", ver, errInvalidTLSClientHello) + return fmt.Errorf("%#04x is not a valid TLS version: %w", ver, errInvalidTLSClientHello) } if len := hdr.PayloadLen(); !isValidPayloadLenForHandshake(len) { - b.valid = false - return hdr, fmt.Errorf("message length %v out of range: %w", len, errInvalidTLSClientHello) + return fmt.Errorf("message length %v out of range: %w", len, errInvalidTLSClientHello) } - return hdr, nil + return nil } diff --git a/transport/tlsfrag/writer.go b/transport/tlsfrag/writer.go index cfc807f7..b48e32b5 100644 --- a/transport/tlsfrag/writer.go +++ b/transport/tlsfrag/writer.go @@ -25,11 +25,14 @@ import ( // are not modified and are directly transmitted through the base [io.Writer]. type clientHelloFragWriter struct { base io.Writer - done bool // Indicates all splitted rcds have been already written to base + // Indicates all splitted rcds have been already written to base + done bool frag FragFunc - helloBuf *clientHelloBuffer // The buffer containing and parsing a TLS Client Hello record - record *bytes.Buffer // The buffer containing splitted records what will be written to base + // The buffer containing and parsing a TLS Client Hello record + helloBuf *clientHelloBuffer + // The buffer containing splitted records what will be written to base + record *bytes.Buffer } // clientHelloFragReaderFrom serves as an optimized version of clientHelloFragWriter when the base [io.Writer] also @@ -140,29 +143,36 @@ func (w *clientHelloFragWriter) copyHelloBufToRecord() { // splitHelloBufToRecord splits w.helloBuf into two records and put them into w.record without allocations. func (w *clientHelloFragWriter) splitHelloBufToRecord() { - originalRecord := w.helloBuf.Bytes() - content := received[recordHeaderLen:] + original := w.helloBuf.Bytes() + content := original[recordHeaderLen:] headLen := w.frag(content) - if split <= 0 || split >= len(content) { + if headLen <= 0 || headLen >= len(content) { w.copyHelloBufToRecord() return } + tailLen := len(content) - headLen - // received: | <== header (5) ==> | <== split ==> | <== len(content)-split ==> | ... cap with padding (5) ... | - // \ \ - // +-----------------+ +-----------------+ - // \ \ - // splitted: | <== header (5) ==> | <== split ==> | <== header2 (5) ==> | <== len(content)-split ==> | - splitted := received[:len(received)+recordHeaderLen] + // | header | payload | cap==len+5 + // original: | <= (5) => | <= head => | <= tail => | <= (5) => | + // | |\ \ + // | | \-------\ \-------\ + // | | \ \ + // splitted: | <= (5) => | <= head => | <= (5) => | <= tail => | + // | header1 | payload1 | header2 | payload2 | + splitted := original[:len(original)+recordHeaderLen] hdr1 := tlsRecordHeaderFromRawBytes(splitted[:recordHeaderLen]) - hdr2 := tlsRecordHeaderFromRawBytes(splitted[recordHeaderLen+split : recordHeaderLen*2+split]) - // Shift tail fragment to make space for record header. - recvContent2 := splitted[recordHeaderLen+split : len(received)] - content2 := splitted[recordHeaderLen*2+split:] - copy(content2, recvContent2) - // Insert header for second fragment. + hdr1.SetPayloadLen(uint16(headLen)) + + // Shift tail fragment to make space for record header. + tail := splitted[recordHeaderLen+headLen : len(original)] + payload2 := splitted[recordHeaderLen*2+headLen:] + copy(payload2, tail) + + // Insert header for second fragment. + hdr2 := tlsRecordHeaderFromRawBytes(splitted[recordHeaderLen+headLen : recordHeaderLen*2+headLen]) copy(hdr2, hdr1) - hdr2.SetPayloadLen(uint16(len(content) - split)) + hdr2.SetPayloadLen(uint16(tailLen)) + w.record = bytes.NewBuffer(splitted) w.helloBuf = nil // allows the GC to recycle the memory } From 16d2d9baf6bcd58f49aa02b22a7b969aa3905a32 Mon Sep 17 00:00:00 2001 From: jyyi1 Date: Wed, 22 Nov 2023 18:17:56 -0500 Subject: [PATCH 22/23] wrong comment --- transport/tlsfrag/buffer.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transport/tlsfrag/buffer.go b/transport/tlsfrag/buffer.go index 460591ad..33a6b67c 100644 --- a/transport/tlsfrag/buffer.go +++ b/transport/tlsfrag/buffer.go @@ -105,7 +105,7 @@ func (b *clientHelloBuffer) ReadFrom(r io.Reader) (n int64, err error) { } } - // If the buffer is already invalid, or has already fully received the record + // If the buffer is already invalid if b.validationErr != nil { err = b.validationErr return From 711162dbe216d1c05bd65bd2eb422f5d8dc9fae7 Mon Sep 17 00:00:00 2001 From: jyyi1 Date: Wed, 22 Nov 2023 18:23:04 -0500 Subject: [PATCH 23/23] simplify getting tail from original record --- transport/tlsfrag/writer.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transport/tlsfrag/writer.go b/transport/tlsfrag/writer.go index b48e32b5..520f8379 100644 --- a/transport/tlsfrag/writer.go +++ b/transport/tlsfrag/writer.go @@ -164,7 +164,7 @@ func (w *clientHelloFragWriter) splitHelloBufToRecord() { hdr1.SetPayloadLen(uint16(headLen)) // Shift tail fragment to make space for record header. - tail := splitted[recordHeaderLen+headLen : len(original)] + tail := original[recordHeaderLen+headLen:] payload2 := splitted[recordHeaderLen*2+headLen:] copy(payload2, tail)