Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cleanup: clean up TCP calls and use netip #179

Merged
merged 4 commits into from
May 13, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/go.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,4 @@ jobs:
run: go build -v ./...

- name: Test
run: go test -v -race -benchmem -bench=. ./... -benchtime=100ms
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By removing -v, we only see the output for failures, which is more helpful

run: go test -race -benchmem -bench=. ./... -benchtime=100ms
1 change: 1 addition & 0 deletions cmd/outline-ss-server/metrics_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ func BenchmarkCloseTCP(b *testing.B) {
duration := time.Minute
b.ResetTimer()
for i := 0; i < b.N; i++ {
ssMetrics.AddAuthenticatedTCPConnection(addr, accessKey)
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without this we get a ton of warnings.

ssMetrics.AddClosedTCPConnection(ipinfo, addr, accessKey, status, data, duration)
ssMetrics.AddTCPCipherSearch(true, timeToCipher)
}
Expand Down
5 changes: 3 additions & 2 deletions internal/integration_test/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"fmt"
"io"
"net"
"net/netip"
"sync"
"testing"
"time"
Expand Down Expand Up @@ -107,7 +108,7 @@ func startUDPEchoServer(t testing.TB) (*net.UDPConn, *sync.WaitGroup) {
t.Logf("Failed to read from UDP conn: %v", err)
return
}
conn.WriteTo(buf[:n], clientAddr)
_, err = conn.WriteTo(buf[:n], clientAddr)
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug fix

if err != nil {
t.Fatalf("Failed to write: %v", err)
}
Expand Down Expand Up @@ -335,7 +336,7 @@ func TestUDPEcho(t *testing.T) {
proxyConn.Close()
<-done
// Verify that the expected metrics were reported.
snapshot := cipherList.SnapshotForClientIP(nil)
snapshot := cipherList.SnapshotForClientIP(netip.Addr{})
keyID := snapshot[0].Value.(*service.CipherEntry).ID

if testMetrics.natAdded != 1 {
Expand Down
16 changes: 8 additions & 8 deletions service/cipher_list.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ package service

import (
"container/list"
"net"
"net/netip"
"sync"

"github.com/Jigsaw-Code/outline-sdk/transport/shadowsocks"
Expand All @@ -31,7 +31,7 @@ type CipherEntry struct {
ID string
CryptoKey *shadowsocks.EncryptionKey
SaltGenerator ServerSaltGenerator
lastClientIP net.IP
lastClientIP netip.Addr
}

// MakeCipherEntry constructs a CipherEntry.
Expand All @@ -56,8 +56,8 @@ func MakeCipherEntry(id string, cryptoKey *shadowsocks.EncryptionKey, secret str
// snapshotting and moving to front.
type CipherList interface {
// Returns a snapshot of the cipher list optimized for this client IP
SnapshotForClientIP(clientIP net.IP) []*list.Element
MarkUsedByClientIP(e *list.Element, clientIP net.IP)
SnapshotForClientIP(clientIP netip.Addr) []*list.Element
MarkUsedByClientIP(e *list.Element, clientIP netip.Addr)
// Update replaces the current contents of the CipherList with `contents`,
// which is a List of *CipherEntry. Update takes ownership of `contents`,
// which must not be read or written after this call.
Expand All @@ -75,12 +75,12 @@ func NewCipherList() CipherList {
return &cipherList{list: list.New()}
}

func matchesIP(e *list.Element, clientIP net.IP) bool {
func matchesIP(e *list.Element, clientIP netip.Addr) bool {
c := e.Value.(*CipherEntry)
return clientIP != nil && clientIP.Equal(c.lastClientIP)
return clientIP != netip.Addr{} && clientIP == c.lastClientIP
}

func (cl *cipherList) SnapshotForClientIP(clientIP net.IP) []*list.Element {
func (cl *cipherList) SnapshotForClientIP(clientIP netip.Addr) []*list.Element {
cl.mu.RLock()
defer cl.mu.RUnlock()
cipherArray := make([]*list.Element, cl.list.Len())
Expand All @@ -102,7 +102,7 @@ func (cl *cipherList) SnapshotForClientIP(clientIP net.IP) []*list.Element {
return cipherArray
}

func (cl *cipherList) MarkUsedByClientIP(e *list.Element, clientIP net.IP) {
func (cl *cipherList) MarkUsedByClientIP(e *list.Element, clientIP netip.Addr) {
cl.mu.Lock()
defer cl.mu.Unlock()
cl.list.MoveToFront(e)
Expand Down
12 changes: 6 additions & 6 deletions service/cipher_list_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,18 @@ package service

import (
"math/rand"
"net"
"net/netip"
"testing"
)

func BenchmarkLocking(b *testing.B) {
var ip net.IP
var ip netip.Addr

ciphers, _ := MakeTestCiphers([]string{"secret"})
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
entries := ciphers.SnapshotForClientIP(nil)
entries := ciphers.SnapshotForClientIP(netip.Addr{})
ciphers.MarkUsedByClientIP(entries[0], ip)
}
})
Expand All @@ -43,20 +43,20 @@ func BenchmarkSnapshot(b *testing.B) {

// Shuffling simulates the behavior of a real server, where successive
// ciphers are not expected to be nearby in memory.
entries := ciphers.SnapshotForClientIP(nil)
entries := ciphers.SnapshotForClientIP(netip.Addr{})
rand.Shuffle(N, func(i, j int) {
entries[i], entries[j] = entries[j], entries[i]
})
for _, entry := range entries {
// Reorder the list to match the shuffle
// (actually in reverse, but it doesn't matter).
ciphers.MarkUsedByClientIP(entry, nil)
ciphers.MarkUsedByClientIP(entry, netip.Addr{})
}

b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
ciphers.SnapshotForClientIP(nil)
ciphers.SnapshotForClientIP(netip.Addr{})
}
})
}
25 changes: 13 additions & 12 deletions service/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"fmt"
"io"
"net"
"net/netip"
"sync"
"syscall"
"time"
Expand All @@ -46,19 +47,19 @@ type TCPMetrics interface {
AddTCPProbe(status, drainResult string, port int, clientProxyBytes int64)
}

func remoteIP(conn net.Conn) net.IP {
func remoteIP(conn net.Conn) netip.Addr {
addr := conn.RemoteAddr()
if addr == nil {
return nil
return netip.Addr{}
}
if tcpaddr, ok := addr.(*net.TCPAddr); ok {
return tcpaddr.IP
return tcpaddr.AddrPort().Addr()
}
ipstr, _, err := net.SplitHostPort(addr.String())
addrPort, err := netip.ParseAddrPort(addr.String())
if err == nil {
return net.ParseIP(ipstr)
return addrPort.Addr()
}
return nil
return netip.Addr{}
}

// Wrapper for logger.Debugf during TCP access key searches.
Expand All @@ -76,7 +77,7 @@ func debugTCP(cipherID, template string, val interface{}) {
// required = saltSize + 2 + cipher.TagSize, the number of bytes needed to authenticate the connection.
const bytesForKeyFinding = 50

func findAccessKey(clientReader io.Reader, clientIP net.IP, cipherList CipherList) (*CipherEntry, io.Reader, []byte, time.Duration, error) {
func findAccessKey(clientReader io.Reader, clientIP netip.Addr, cipherList CipherList) (*CipherEntry, io.Reader, []byte, time.Duration, error) {
// We snapshot the list because it may be modified while we use it.
ciphers := cipherList.SnapshotForClientIP(clientIP)
firstBytes := make([]byte, bytesForKeyFinding)
Expand Down Expand Up @@ -264,7 +265,7 @@ func (h *tcpHandler) Handle(ctx context.Context, clientConn transport.StreamConn
measuredClientConn := metrics.MeasureConn(clientConn, &proxyMetrics.ProxyClient, &proxyMetrics.ClientProxy)
connStart := time.Now()

id, connError := h.handleConnection(ctx, h.port, clientInfo, measuredClientConn, &proxyMetrics)
id, connError := h.handleConnection(ctx, measuredClientConn, &proxyMetrics)

connDuration := time.Since(connStart)
status := "OK"
Expand Down Expand Up @@ -327,7 +328,7 @@ func proxyConnection(ctx context.Context, dialer transport.StreamDialer, tgtAddr
return nil
}

func (h *tcpHandler) handleConnection(ctx context.Context, listenerPort int, clientInfo ipinfo.IPInfo, outerConn transport.StreamConn, proxyMetrics *metrics.ProxyMetrics) (string, *onet.ConnectionError) {
func (h *tcpHandler) handleConnection(ctx context.Context, outerConn transport.StreamConn, proxyMetrics *metrics.ProxyMetrics) (string, *onet.ConnectionError) {
// Set a deadline to receive the address to the target.
readDeadline := time.Now().Add(h.readTimeout)
if deadline, ok := ctx.Deadline(); ok {
Expand All @@ -341,7 +342,7 @@ func (h *tcpHandler) handleConnection(ctx context.Context, listenerPort int, cli
id, innerConn, authErr := h.authenticate(outerConn)
if authErr != nil {
// Drain to protect against probing attacks.
h.absorbProbe(listenerPort, outerConn, authErr.Status, proxyMetrics)
h.absorbProbe(outerConn, authErr.Status, proxyMetrics)
return id, authErr
}
h.m.AddAuthenticatedTCPConnection(outerConn.RemoteAddr(), id)
Expand Down Expand Up @@ -369,12 +370,12 @@ func (h *tcpHandler) handleConnection(ctx context.Context, listenerPort int, cli

// Keep the connection open until we hit the authentication deadline to protect against probing attacks
// `proxyMetrics` is a pointer because its value is being mutated by `clientConn`.
func (h *tcpHandler) absorbProbe(listenerPort int, clientConn io.ReadCloser, status string, proxyMetrics *metrics.ProxyMetrics) {
func (h *tcpHandler) absorbProbe(clientConn io.ReadCloser, status string, proxyMetrics *metrics.ProxyMetrics) {
// This line updates proxyMetrics.ClientProxy before it's used in AddTCPProbe.
_, drainErr := io.Copy(io.Discard, clientConn) // drain socket
drainResult := drainErrToString(drainErr)
logger.Debugf("Drain error: %v, drain result: %v", drainErr, drainResult)
h.m.AddTCPProbe(status, drainResult, listenerPort, proxyMetrics.ClientProxy)
h.m.AddTCPProbe(status, drainResult, h.port, proxyMetrics.ClientProxy)
}

func drainErrToString(drainErr error) string {
Expand Down
20 changes: 9 additions & 11 deletions service/tcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"io"
"math/rand"
"net"
"net/netip"
"sync"
"testing"
"time"
Expand Down Expand Up @@ -99,7 +100,7 @@ func BenchmarkTCPFindCipherFail(b *testing.B) {
if err != nil {
b.Fatalf("AcceptTCP failed: %v", err)
}
clientIP := clientConn.RemoteAddr().(*net.TCPAddr).IP
clientIP := clientConn.RemoteAddr().(*net.TCPAddr).AddrPort().Addr()
b.StartTimer()
findAccessKey(clientConn, clientIP, cipherList)
b.StopTimer()
Expand Down Expand Up @@ -191,16 +192,16 @@ func BenchmarkTCPFindCipherRepeat(b *testing.B) {
b.Fatal(err)
}
cipherEntries := [numCiphers]*CipherEntry{}
snapshot := cipherList.SnapshotForClientIP(nil)
snapshot := cipherList.SnapshotForClientIP(netip.Addr{})
for cipherNumber, element := range snapshot {
cipherEntries[cipherNumber] = element.Value.(*CipherEntry)
}
for n := 0; n < b.N; n++ {
cipherNumber := byte(n % numCiphers)
reader, writer := io.Pipe()
clientIP := net.IPv4(192, 0, 2, cipherNumber)
addr := &net.TCPAddr{IP: clientIP, Port: 54321}
c := conn{clientAddr: addr, reader: reader, writer: writer}
clientIP := netip.AddrFrom4([4]byte{192, 0, 2, cipherNumber})
addr := netip.AddrPortFrom(clientIP, 54321)
c := conn{clientAddr: net.TCPAddrFromAddrPort(addr), reader: reader, writer: writer}
cipher := cipherEntries[cipherNumber].CryptoKey
go shadowsocks.NewWriter(writer, cipher).Write(makeTestPayload(50))
b.StartTimer()
Expand Down Expand Up @@ -345,7 +346,7 @@ func makeClientBytesCoalesced(t *testing.T, cryptoKey *shadowsocks.EncryptionKey
}

func firstCipher(cipherList CipherList) *shadowsocks.EncryptionKey {
snapshot := cipherList.SnapshotForClientIP(nil)
snapshot := cipherList.SnapshotForClientIP(netip.Addr{})
cipherEntry := snapshot[0].Value.(*CipherEntry)
return cipherEntry.CryptoKey
}
Expand All @@ -368,7 +369,6 @@ func TestProbeClientBytesBasicTruncated(t *testing.T) {
discardListener, discardWait := startDiscardServer(t)
initialBytes := makeClientBytesBasic(t, cipher, discardListener.Addr().String())
for numBytesToSend := 0; numBytesToSend < len(initialBytes); numBytesToSend++ {
t.Logf("Sending %v bytes", numBytesToSend)
bytesToSend := initialBytes[:numBytesToSend]
err := probe(listener.Addr().(*net.TCPAddr), bytesToSend)
require.NoError(t, err, "Failed for %v bytes sent: %v", numBytesToSend, err)
Expand Down Expand Up @@ -405,7 +405,6 @@ func TestProbeClientBytesBasicModified(t *testing.T) {
initialBytes := makeClientBytesBasic(t, cipher, discardListener.Addr().String())
bytesToSend := make([]byte, len(initialBytes))
for byteToModify := 0; byteToModify < len(initialBytes); byteToModify++ {
t.Logf("Modifying byte %v", byteToModify)
copy(bytesToSend, initialBytes)
bytesToSend[byteToModify] = 255 - bytesToSend[byteToModify]
err := probe(listener.Addr().(*net.TCPAddr), bytesToSend)
Expand Down Expand Up @@ -442,7 +441,6 @@ func TestProbeClientBytesCoalescedModified(t *testing.T) {
initialBytes := makeClientBytesCoalesced(t, cipher, discardListener.Addr().String())
bytesToSend := make([]byte, len(initialBytes))
for byteToModify := 0; byteToModify < len(initialBytes); byteToModify++ {
t.Logf("Modifying byte %v", byteToModify)
copy(bytesToSend, initialBytes)
bytesToSend[byteToModify] = 255 - bytesToSend[byteToModify]
err := probe(listener.Addr().(*net.TCPAddr), bytesToSend)
Expand Down Expand Up @@ -506,7 +504,7 @@ func TestReplayDefense(t *testing.T) {
const testTimeout = 200 * time.Millisecond
authFunc := NewShadowsocksStreamAuthenticator(cipherList, &replayCache, testMetrics)
handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, authFunc, testMetrics, testTimeout)
snapshot := cipherList.SnapshotForClientIP(nil)
snapshot := cipherList.SnapshotForClientIP(netip.Addr{})
cipherEntry := snapshot[0].Value.(*CipherEntry)
cipher := cipherEntry.CryptoKey
reader, writer := io.Pipe()
Expand Down Expand Up @@ -585,7 +583,7 @@ func TestReverseReplayDefense(t *testing.T) {
const testTimeout = 200 * time.Millisecond
authFunc := NewShadowsocksStreamAuthenticator(cipherList, &replayCache, testMetrics)
handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, authFunc, testMetrics, testTimeout)
snapshot := cipherList.SnapshotForClientIP(nil)
snapshot := cipherList.SnapshotForClientIP(netip.Addr{})
cipherEntry := snapshot[0].Value.(*CipherEntry)
cipher := cipherEntry.CryptoKey
reader, writer := io.Pipe()
Expand Down
5 changes: 3 additions & 2 deletions service/udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"errors"
"fmt"
"net"
"net/netip"
"runtime/debug"
"sync"
"time"
Expand Down Expand Up @@ -64,7 +65,7 @@ func debugUDPAddr(addr net.Addr, template string, val interface{}) {

// Decrypts src into dst. It tries each cipher until it finds one that authenticates
// correctly. dst and src must not overlap.
func findAccessKeyUDP(clientIP net.IP, dst, src []byte, cipherList CipherList) ([]byte, string, *shadowsocks.EncryptionKey, error) {
func findAccessKeyUDP(clientIP netip.Addr, dst, src []byte, cipherList CipherList) ([]byte, string, *shadowsocks.EncryptionKey, error) {
// Try each cipher until we find one that authenticates successfully. This assumes that all ciphers are AEAD.
// We snapshot the list because it may be modified while we use it.
snapshot := cipherList.SnapshotForClientIP(clientIP)
Expand Down Expand Up @@ -156,7 +157,7 @@ func (h *packetHandler) Handle(clientConn net.PacketConn) {
}
debugUDPAddr(clientAddr, "Got info \"%#v\"", clientInfo)

ip := clientAddr.(*net.UDPAddr).IP
ip := clientAddr.(*net.UDPAddr).AddrPort().Addr()
var textData []byte
var cryptoKey *shadowsocks.EncryptionKey
unpackStart := time.Now()
Expand Down
Loading
Loading