From 107fc18a234d166dfd087606539b1bb6a2d3389e Mon Sep 17 00:00:00 2001 From: Guangming Li Date: Wed, 27 Jan 2021 11:54:56 +0800 Subject: [PATCH] fix obfs-tls with shadowsocks aead issue #695 --- obfs.go | 157 ++++++++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 123 insertions(+), 34 deletions(-) diff --git a/obfs.go b/obfs.go index 4e161d3e..44546841 100644 --- a/obfs.go +++ b/obfs.go @@ -313,8 +313,31 @@ var ( 0x0601, 0x0602, 0x0603, 0x0501, 0x0502, 0x0503, 0x0401, 0x0402, 0x0403, 0x0301, 0x0302, 0x0303, 0x0201, 0x0202, 0x0203, } + + tlsRecordTypes = []uint8{0x16, 0x14, 0x16, 0x17} + tlsVersionMinors = []uint8{0x01, 0x03, 0x03, 0x03} + + ErrBadType = errors.New("bad type") + ErrBadMajorVersion = errors.New("bad major version") + ErrBadMinorVersion = errors.New("bad minor version") + ErrMaxDataLen = errors.New("bad tls data len") +) + +const ( + tlsRecordStateType = iota + tlsRecordStateVersion0 + tlsRecordStateVersion1 + tlsRecordStateLength0 + tlsRecordStateLength1 + tlsRecordStateData ) +type obfsTLSParser struct { + step uint8 + state uint8 + length uint16 +} + type obfsTLSConn struct { net.Conn rbuf bytes.Buffer @@ -322,15 +345,96 @@ type obfsTLSConn struct { host string isServer bool handshaked chan struct{} + parser *obfsTLSParser handshakeMutex sync.Mutex } +func (r *obfsTLSParser) Parse(b []byte) (int, error) { + i := 0 + last := 0 + length := len(b) + + for i < length { + ch := b[i] + switch r.state { + case tlsRecordStateType: + if tlsRecordTypes[r.step] != ch { + return 0, ErrBadType + } + r.state = tlsRecordStateVersion0 + i++ + case tlsRecordStateVersion0: + if ch != 0x03 { + return 0, ErrBadMajorVersion + } + r.state = tlsRecordStateVersion1 + i++ + case tlsRecordStateVersion1: + if ch != tlsVersionMinors[r.step] { + return 0, ErrBadMinorVersion + } + r.state = tlsRecordStateLength0 + i++ + case tlsRecordStateLength0: + r.length = uint16(ch) << 8 + r.state = tlsRecordStateLength1 + i++ + case tlsRecordStateLength1: + r.length |= uint16(ch) + if r.step == 0 { + r.length = 91 + } else if r.step == 1 { + r.length = 1 + } else if r.length > maxTLSDataLen { + return 0, ErrMaxDataLen + } + if r.length > 0 { + r.state = tlsRecordStateData + } else { + r.state = tlsRecordStateType + r.step++ + } + i++ + case tlsRecordStateData: + left := uint16(length - i) + if left > r.length { + left = r.length + } + if r.step >= 2 { + skip := i - last + copy(b[last:], b[i:length]) + length -= int(skip) + last += int(left) + i = last + } else { + i += int(left) + } + r.length -= left + if r.length == 0 { + if r.step < 3 { + r.step++ + } + r.state = tlsRecordStateType + } + } + } + + if last == 0 { + return 0, nil + } else if last < length { + length -= last + } + + return length, nil +} + // ClientObfsTLSConn creates a connection for obfs-tls client. func ClientObfsTLSConn(conn net.Conn, host string) net.Conn { return &obfsTLSConn{ Conn: conn, host: host, handshaked: make(chan struct{}), + parser: &obfsTLSParser{}, } } @@ -416,32 +520,6 @@ func (c *obfsTLSConn) clientHandshake(payload []byte) error { if _, err := record.WriteTo(c.Conn); err != nil { return err } - - // server hello handshake message - if _, err := record.ReadFrom(c.Conn); err != nil { - return err - } - if record.Type != dissector.Handshake { - return dissector.ErrBadType - } - - // change cipher spec message - if _, err := record.ReadFrom(c.Conn); err != nil { - return err - } - if record.Type != dissector.ChangeCipherSpec { - return dissector.ErrBadType - } - - // encrypted handshake message - if _, err := record.ReadFrom(c.Conn); err != nil { - return err - } - if record.Type != dissector.Handshake { - return dissector.ErrBadType - } - - _, err = c.rbuf.Write(record.Opaque) return err } @@ -521,19 +599,30 @@ func (c *obfsTLSConn) Read(b []byte) (n int, err error) { return } } + select { case <-c.handshaked: } - if c.rbuf.Len() > 0 { - return c.rbuf.Read(b) - } - record := &dissector.Record{} - if _, err = record.ReadFrom(c.Conn); err != nil { - return + if c.isServer { + if c.rbuf.Len() > 0 { + return c.rbuf.Read(b) + } + record := &dissector.Record{} + if _, err = record.ReadFrom(c.Conn); err != nil { + return + } + n = copy(b, record.Opaque) + _, err = c.rbuf.Write(record.Opaque[n:]) + } else { + n, err = c.Conn.Read(b) + if err != nil { + return + } + if n > 0 { + n, err = c.parser.Parse(b[:n]) + } } - n = copy(b, record.Opaque) - _, err = c.rbuf.Write(record.Opaque[n:]) return }