Skip to content

Commit

Permalink
Reverse port forwarding implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
jdhozdiego committed Jul 26, 2024
1 parent 20f2894 commit 5abff6d
Show file tree
Hide file tree
Showing 6 changed files with 410 additions and 1 deletion.
117 changes: 117 additions & 0 deletions channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,16 @@ type TCPForwardingChannelImpl struct {
Channel
}

type TCPReverseForwardingChannelImpl struct {
RemoteAddr *net.TCPAddr
LocalAddr *net.TCPAddr
Channel
}
type TCPOpenReverseForwardingChannelImpl struct {
RemoteAddr *net.TCPAddr
Channel
}

func buildHeader(conversationStreamID uint64, channelType string, maxPacketSize uint64, additionalBytes []byte) []byte {
channelTypeBuf := make([]byte, util.SSHStringLen(channelType))
util.WriteSSHString(channelTypeBuf, channelType)
Expand Down Expand Up @@ -159,6 +169,38 @@ func buildForwardingChannelAdditionalBytes(remoteAddr net.IP, port uint16) []byt
buf = append(buf, portBuf[:]...)
return buf
}
func buildRequestTCPReverseChannelAdditionalBytes(localAddr net.IP, localPort uint16, remoteAddr net.IP, remotePort uint16) []byte {
var buf []byte
var portBuf [2]byte
//var portBuf2 [2]byte

var addressFamily util.SSHForwardingAddressFamily
if len(localAddr) == 4 {
addressFamily = util.SSHAFIpv4
} else {
addressFamily = util.SSHAFIpv6
}

buf = util.AppendVarInt(buf, addressFamily)
buf = append(buf, localAddr...)
binary.BigEndian.PutUint16(portBuf[:], uint16(localPort))
buf = append(buf, portBuf[:]...)

if len(remoteAddr) == 4 {
addressFamily = util.SSHAFIpv4
} else {
addressFamily = util.SSHAFIpv6
}

buf = util.AppendVarInt(buf, addressFamily)
buf = append(buf, remoteAddr...)
binary.BigEndian.PutUint16(portBuf[:], uint16(remotePort))
buf = append(buf, portBuf[:]...)
//TODO: If I do not duplicate this, the port does not arrive to destination
buf = append(buf, portBuf[:]...)
return buf

}

func parseHeader(channelID uint64, r util.Reader) (conversationControlStreamID ControlStreamID, channelType string, maxPacketSize uint64, err error) {
conversationControlStreamID, err = util.ReadVarInt(r)
Expand Down Expand Up @@ -206,6 +248,67 @@ func parseForwardingHeader(channelID uint64, buf util.Reader) (net.IP, uint16, e
return address, port, nil
}

func parseRequestReverseHeader(channelID uint64, buf util.Reader) (net.IP, uint16, net.IP, uint16, error) {

var localaddress net.IP
var remoteaddress net.IP
var portBuf [2]byte

//Parse local address and port where the reverse socket is proxied within the client machine
//------------------------------------------------------------------------------------------
addressFamily, err := util.ReadVarInt(buf)
if err != nil {
return nil, 0, nil, 0, err
}

if addressFamily == util.SSHAFIpv4 {
localaddress = make([]byte, 4)
} else if addressFamily == util.SSHAFIpv6 {
localaddress = make([]byte, 16)
} else {
return nil, 0, nil, 0, fmt.Errorf("invalid local address family: %d", addressFamily)
}

_, err = buf.Read(localaddress)
if err != nil {
return nil, 0, nil, 0, err
}

_, err = buf.Read(portBuf[:])
if err != nil {
return nil, 0, nil, 0, err
}
localport := binary.BigEndian.Uint16(portBuf[:])

//Parse remote address and port of the remote service to be proxied within the client machine
//-------------------------------------------------------------------------------------------
addressFamily, err = util.ReadVarInt(buf)
if err != nil {
return nil, 0, nil, 0, err
}

if addressFamily == util.SSHAFIpv4 {
remoteaddress = make([]byte, 4)
} else if addressFamily == util.SSHAFIpv6 {
remoteaddress = make([]byte, 16)
} else {
return nil, 0, nil, 0, fmt.Errorf("invalid remote address family: %d", addressFamily)
}

_, err = buf.Read(remoteaddress)
if err != nil {
return nil, 0, nil, 0, err
}

_, err = buf.Read(portBuf[:])
if err != nil {
return nil, 0, nil, 0, err
}
remoteport := binary.BigEndian.Uint16(portBuf[:])

return localaddress, localport, remoteaddress, remoteport, nil
}

func parseUDPForwardingHeader(channelID uint64, buf util.Reader) (*net.UDPAddr, error) {
address, port, err := parseForwardingHeader(channelID, buf)
if err != nil {
Expand All @@ -228,6 +331,20 @@ func parseTCPForwardingHeader(channelID uint64, buf util.Reader) (*net.TCPAddr,
}, nil
}

func parseTCPRequestReverseHeader(channelID uint64, buf util.Reader) (*net.TCPAddr, *net.TCPAddr, error) {
localaddress, localport, remoteaddress, remoteport, err := parseRequestReverseHeader(channelID, buf)
if err != nil {
return nil, nil, err
}
return &net.TCPAddr{
IP: localaddress,
Port: int(localport),
}, &net.TCPAddr{
IP: remoteaddress,
Port: int(remoteport),
}, nil
}

func NewChannel(conversationStreamID uint64, conversationID ConversationID, channelID uint64, channelType string, maxPacketSize uint64, recv quic.ReceiveStream,
send io.WriteCloser, datagramSender util.SSH3DatagramSenderFunc, channelCloseListener channelCloseListener, sendHeader bool, confirmSent bool,
confirmReceived bool, datagramsQueueSize uint64, additonalHeaderBytes []byte) Channel {
Expand Down
114 changes: 114 additions & 0 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,84 @@ func forwardTCPInBackground(ctx context.Context, channel ssh3.Channel, conn *net
}()
}

func forwardReverseTCPInBackground(ctx context.Context, channel ssh3.Channel, conn *net.TCPConn) {
go func() {
defer conn.CloseWrite()
for {
select {
case <-ctx.Done():
return
default:
}
genericMessage, err := channel.NextMessage()
if err == io.EOF {
log.Info().Msgf("eof on tcp-forwarding channel %d", channel.ChannelID())
} else if err != nil {
log.Error().Msgf("could get message from tcp forwarding channel: %s", err)
return
}

// nothing to process
if genericMessage == nil {
return
}

switch message := genericMessage.(type) {
case *ssh3Messages.DataOrExtendedDataMessage:
if message.DataType == ssh3Messages.SSH_EXTENDED_DATA_NONE {
_, err := conn.Write([]byte(message.Data))
if err != nil {
log.Error().Msgf("could not write data on TCP socket: %s", err)
// signal the write error to the peer
channel.CancelRead()
return
}
} else {
log.Warn().Msgf("ignoring message data of unexpected type %d on TCP forwarding channel %d", message.DataType, channel.ChannelID())
}
default:
log.Warn().Msgf("ignoring message of type %T on TCP forwarding channel %d", message, channel.ChannelID())
}
}
}()

go func() {
defer channel.Close()
defer conn.CloseRead()
buf := make([]byte, channel.MaxPacketSize())
for {
select {
case <-ctx.Done():
return
default:
}
n, err := conn.Read(buf)
if err != nil && err != io.EOF {
log.Error().Msgf("could read data on TCP socket: %s", err)
return
}
//log.Debug().Msgf("Reading from socket: %s", string(buf))
_, errWrite := channel.WriteData(buf[:n], ssh3Messages.SSH_EXTENDED_DATA_NONE)
if errWrite != nil {
switch quicErr := errWrite.(type) {
case *quic.StreamError:
if quicErr.Remote && quicErr.ErrorCode == 42 {
log.Info().Msgf("writing was canceled by the remote, closing the socket: %s", errWrite)
} else {
log.Error().Msgf("unhandled quic stream error: %+v", quicErr)
}
default:
log.Error().Msgf("could send data on channel: %s", errWrite)
}
return
}
if err == io.EOF {
return
}
}
}()
}

type Client struct {
qconn quic.EarlyConnection
*ssh3.Conversation
Expand Down Expand Up @@ -453,6 +531,42 @@ func (c *Client) ForwardTCP(ctx context.Context, localTCPAddr *net.TCPAddr, remo
return conn.Addr().(*net.TCPAddr), nil
}

func (c *Client) ReverseTCP(ctx context.Context, localTCPAddr *net.TCPAddr, remoteTCPAddr *net.TCPAddr) (*net.TCPAddr, error) {
log.Debug().Msgf("start TCP forwarding from %s to %s", localTCPAddr, remoteTCPAddr)

forwardingChannel, err := c.RequestTCPReverseChannel(30000, 10, localTCPAddr, remoteTCPAddr)
if err != nil {
log.Error().Msgf("could open new TCP reverse forwarding channel: %s", err)
return remoteTCPAddr, nil
}

go func() {
for {
channel, err := c.AcceptChannel(c.Context())
if err != nil {
log.Debug().Msgf("Error accepting channel")
}

switch channel.ChannelType() {
case "open-request-reverse-tcp":
log.Debug().Msgf("start reverse TCP forwarding from %s to %s", localTCPAddr, remoteTCPAddr)

conn, err := net.DialTCP("tcp", nil, remoteTCPAddr)
if err != nil {
return
}
forwardReverseTCPInBackground(ctx, channel, conn)
if err != nil {
channel.Close()
return
}
}
}
}()
forwardingChannel.Close()
return remoteTCPAddr, nil
}

func (c *Client) RunSession(tty *os.File, forwardSSHAgent bool, command ...string) error {

ctx := c.Context()
Expand Down
106 changes: 106 additions & 0 deletions cmd/ssh3-server.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,84 @@ func forwardTCPInBackground(ctx context.Context, channel ssh3.Channel, conn *net
}()
}

func forwardReverseTCPInBackground(ctx context.Context, channel ssh3.Channel, conn *net.TCPConn) {
go func() {
defer channel.Close()
defer conn.CloseRead()
buf := make([]byte, channel.MaxPacketSize())
for {
select {
case <-ctx.Done():
return
default:
}
n, err := conn.Read(buf)
if err != nil && err != io.EOF {
log.Error().Msgf("could read data on TCP socket: %s", err)
return
}
_, errWrite := channel.WriteData(buf[:n], ssh3Messages.SSH_EXTENDED_DATA_NONE)
if errWrite != nil {
switch quicErr := errWrite.(type) {
case *quic.StreamError:
if quicErr.Remote && quicErr.ErrorCode == 42 {
log.Info().Msgf("writing was canceled by the remote, closing the socket: %s", errWrite)
} else {
log.Error().Msgf("unhandled quic stream error: %+v", quicErr)
}
default:
log.Error().Msgf("could send data on channel: %s", errWrite)
}
return
}
if err == io.EOF {
return
}
}
}()

go func() {
defer conn.CloseWrite()
for {
select {
case <-ctx.Done():
return
default:
}
genericMessage, err := channel.NextMessage()
if err == io.EOF {
log.Info().Msgf("eof on reverse-tcp-forwarding channel %d", channel.ChannelID())
} else if err != nil {
log.Error().Msgf("could get message from tcp forwarding channel: %s", err)
return
}

// nothing to process
if genericMessage == nil {
return
}

switch message := genericMessage.(type) {
case *ssh3Messages.DataOrExtendedDataMessage:
if message.DataType == ssh3Messages.SSH_EXTENDED_DATA_NONE {
_, err := conn.Write([]byte(message.Data))
if err != nil {
log.Error().Msgf("could not write data on TCP socket: %s", err)
// signal the write error to the peer
channel.CancelRead()
return
}
} else {
log.Warn().Msgf("ignoring message data of unexpected type %d on TCP forwarding channel %d", message.DataType, channel.ChannelID())
}
default:
log.Warn().Msgf("ignoring message of type %T on TCP forwarding channel %d", message, channel.ChannelID())
}
}
}()

}

func execCmdInBackground(channel ssh3.Channel, openPty *openPty, user *unix_util.User, runningCommand *runningCommand, authAgentSocketPath string) error {
setupEnv(user, runningCommand, authAgentSocketPath)
if openPty != nil {
Expand Down Expand Up @@ -551,6 +629,32 @@ func handleTCPForwardingChannel(ctx context.Context, user *unix_util.User, conv
return nil
}

func handleTCPReverseForwardingChannel(ctx context.Context, user *unix_util.User, conv *ssh3.Conversation, channel *ssh3.TCPReverseForwardingChannelImpl) error {
conn, err := net.ListenTCP("tcp", channel.LocalAddr)
if err != nil {
log.Error().Msgf("could listen on TCP socket: %s", err)
return nil
}

go func() {
for {
conn, err := conn.AcceptTCP()
if err != nil {
log.Error().Msgf("could read on UDP socket: %s", err)
return
}

forwardingChannel, err := conv.OpenTCPReverseForwardingChannel(30000, 10, channel.RemoteAddr)
if err != nil {
log.Error().Msgf("could not open new TCP reverse forwarding channel: %s", err)
return
}
forwardReverseTCPInBackground(ctx, forwardingChannel, conn)
}
}()
return nil
}

func newDataReq(user *unix_util.User, channel ssh3.Channel, request ssh3Messages.DataOrExtendedDataMessage) error {
runningSession, ok := runningSessions.Get(channel)
if !ok {
Expand Down Expand Up @@ -855,6 +959,8 @@ func ServerMain() int {
handleUDPForwardingChannel(conv.Context(), authenticatedUser, conv, c)
case *ssh3.TCPForwardingChannelImpl:
handleTCPForwardingChannel(conv.Context(), authenticatedUser, conv, c)
case *ssh3.TCPReverseForwardingChannelImpl:
handleTCPReverseForwardingChannel(conv.Context(), authenticatedUser, conv, c)
default:
runningSessions.Insert(channel, &runningSession{
channelState: LARVAL,
Expand Down
Loading

0 comments on commit 5abff6d

Please sign in to comment.