From 068163856833891a6e46a745197044a84ec6a727 Mon Sep 17 00:00:00 2001 From: Haruue Date: Fri, 8 Nov 2024 15:28:50 +0900 Subject: [PATCH 1/3] feat(trafficlogger): dump streams stats --- core/internal/utils/atomic.go | 30 +++++++ core/server/config.go | 65 +++++++++++++++ core/server/copy.go | 7 +- core/server/server.go | 32 +++++++- extras/trafficlogger/http.go | 147 ++++++++++++++++++++++++++++++++++ 5 files changed, 277 insertions(+), 4 deletions(-) diff --git a/core/internal/utils/atomic.go b/core/internal/utils/atomic.go index e3c3d97782..7739013ec0 100644 --- a/core/internal/utils/atomic.go +++ b/core/internal/utils/atomic.go @@ -22,3 +22,33 @@ func (t *AtomicTime) Set(new time.Time) { func (t *AtomicTime) Get() time.Time { return t.v.Load().(time.Time) } + +type Atomic[T any] struct { + v atomic.Value +} + +func (a *Atomic[T]) Load() T { + value := a.v.Load() + if value == nil { + var zero T + return zero + } + return value.(T) +} + +func (a *Atomic[T]) Store(value T) { + a.v.Store(value) +} + +func (a *Atomic[T]) Swap(new T) T { + old := a.v.Swap(new) + if old == nil { + var zero T + return zero + } + return old.(T) +} + +func (a *Atomic[T]) CompareAndSwap(old, new T) bool { + return a.v.CompareAndSwap(old, new) +} diff --git a/core/server/config.go b/core/server/config.go index f90c820557..19aae53bd0 100644 --- a/core/server/config.go +++ b/core/server/config.go @@ -4,8 +4,11 @@ import ( "crypto/tls" "net" "net/http" + "sync/atomic" "time" + "github.com/apernet/hysteria/core/v2/internal/utils" + "github.com/apernet/hysteria/core/v2/errors" "github.com/apernet/hysteria/core/v2/internal/pmtud" "github.com/apernet/quic-go" @@ -212,4 +215,66 @@ type EventLogger interface { type TrafficLogger interface { LogTraffic(id string, tx, rx uint64) (ok bool) LogOnlineState(id string, online bool) + TraceStream(stream quic.Stream, stats *StreamStats) + UntraceStream(stream quic.Stream) +} + +type StreamState int + +const ( + // StreamStateInitial indicates the initial state of a stream. + // Client has opened the stream, but we have not received the proxy request yet. + StreamStateInitial StreamState = iota + + // StreamStateHooking indicates that the hook (usually sniff) is processing. + // Client has sent the proxy request, but sniff requires more data to complete. + StreamStateHooking + + // StreamStateConnecting indicates that we are connecting to the proxy target. + StreamStateConnecting + + // StreamStateEstablished indicates the proxy is established. + StreamStateEstablished + + // StreamStateClosed indicates the stream is closed. + StreamStateClosed +) + +func (s StreamState) String() string { + switch s { + case StreamStateInitial: + return "init" + case StreamStateHooking: + return "hook" + case StreamStateConnecting: + return "connect" + case StreamStateEstablished: + return "estab" + case StreamStateClosed: + return "closed" + default: + return "unknown" + } +} + +type StreamStats struct { + State utils.Atomic[StreamState] + + AuthID string + ConnID uint32 + InitialTime time.Time + + ReqAddr utils.Atomic[string] + HookedReqAddr utils.Atomic[string] + + Tx atomic.Uint64 + Rx atomic.Uint64 + + LastActiveTime utils.Atomic[time.Time] +} + +func (s *StreamStats) setHookedReqAddr(addr string) { + if addr != s.ReqAddr.Load() { + s.HookedReqAddr.Store(addr) + } } diff --git a/core/server/copy.go b/core/server/copy.go index d55dcefe3f..2f99ae42cf 100644 --- a/core/server/copy.go +++ b/core/server/copy.go @@ -3,6 +3,7 @@ package server import ( "errors" "io" + "time" ) var errDisconnect = errors.New("traffic logger requested disconnect") @@ -31,15 +32,19 @@ func copyBufferLog(dst io.Writer, src io.Reader, log func(n uint64) bool) error } } -func copyTwoWayWithLogger(id string, serverRw, remoteRw io.ReadWriter, l TrafficLogger) error { +func copyTwoWayWithLogger(id string, serverRw, remoteRw io.ReadWriter, l TrafficLogger, stats *StreamStats) error { errChan := make(chan error, 2) go func() { errChan <- copyBufferLog(serverRw, remoteRw, func(n uint64) bool { + stats.LastActiveTime.Store(time.Now()) + stats.Rx.Add(n) return l.LogTraffic(id, 0, n) }) }() go func() { errChan <- copyBufferLog(remoteRw, serverRw, func(n uint64) bool { + stats.LastActiveTime.Store(time.Now()) + stats.Tx.Add(n) return l.LogTraffic(id, n, 0) }) }() diff --git a/core/server/server.go b/core/server/server.go index ba55b315b6..f7ad957f81 100644 --- a/core/server/server.go +++ b/core/server/server.go @@ -3,8 +3,10 @@ package server import ( "context" "crypto/tls" + "math/rand" "net/http" "sync" + "time" "github.com/apernet/quic-go" "github.com/apernet/quic-go/http3" @@ -100,6 +102,7 @@ type h3sHandler struct { authenticated bool authMutex sync.Mutex authID string + connID uint32 // a random id for dump streams udpSM *udpSessionManager // Only set after authentication } @@ -108,6 +111,7 @@ func newH3sHandler(config *Config, conn quic.Connection) *h3sHandler { return &h3sHandler{ config: config, conn: conn, + connID: rand.Uint32(), } } @@ -205,12 +209,29 @@ func (h *h3sHandler) ProxyStreamHijacker(ft http3.FrameType, id quic.ConnectionT } func (h *h3sHandler) handleTCPRequest(stream quic.Stream) { + trafficLogger := h.config.TrafficLogger + streamStats := &StreamStats{ + AuthID: h.authID, + ConnID: h.connID, + InitialTime: time.Now(), + } + streamStats.State.Store(StreamStateInitial) + streamStats.LastActiveTime.Store(time.Now()) + defer func() { + streamStats.State.Store(StreamStateClosed) + }() + if trafficLogger != nil { + trafficLogger.TraceStream(stream, streamStats) + defer trafficLogger.UntraceStream(stream) + } + // Read request reqAddr, err := protocol.ReadTCPRequest(stream) if err != nil { _ = stream.Close() return } + streamStats.ReqAddr.Store(reqAddr) // Call the hook if set var putback []byte var hooked bool @@ -220,12 +241,14 @@ func (h *h3sHandler) handleTCPRequest(stream quic.Stream) { // so that the client will send whatever request the hook wants to see. // This is essentially a server-side fast-open. if hooked { + streamStats.State.Store(StreamStateHooking) _ = protocol.WriteTCPResponse(stream, true, "RequestHook enabled") putback, err = h.config.RequestHook.TCP(stream, &reqAddr) if err != nil { _ = stream.Close() return } + streamStats.setHookedReqAddr(reqAddr) } } // Log the event @@ -233,6 +256,7 @@ func (h *h3sHandler) handleTCPRequest(stream quic.Stream) { h.config.EventLogger.TCPRequest(h.conn.RemoteAddr(), h.authID, reqAddr) } // Dial target + streamStats.State.Store(StreamStateConnecting) tConn, err := h.config.Outbound.TCP(reqAddr) if err != nil { if !hooked { @@ -248,13 +272,15 @@ func (h *h3sHandler) handleTCPRequest(stream quic.Stream) { if !hooked { _ = protocol.WriteTCPResponse(stream, true, "Connected") } + streamStats.State.Store(StreamStateEstablished) // Put back the data if the hook requested if len(putback) > 0 { - _, _ = tConn.Write(putback) + n, _ := tConn.Write(putback) + streamStats.Tx.Add(uint64(n)) } // Start proxying - if h.config.TrafficLogger != nil { - err = copyTwoWayWithLogger(h.authID, stream, tConn, h.config.TrafficLogger) + if trafficLogger != nil { + err = copyTwoWayWithLogger(h.authID, stream, tConn, trafficLogger, streamStats) } else { // Use the fast path if no traffic logger is set err = copyTwoWay(stream, tConn) diff --git a/extras/trafficlogger/http.go b/extras/trafficlogger/http.go index 9ab943af17..87eb919f14 100644 --- a/extras/trafficlogger/http.go +++ b/extras/trafficlogger/http.go @@ -1,10 +1,17 @@ package trafficlogger import ( + "cmp" "encoding/json" + "fmt" "net/http" + "slices" "strconv" + "strings" "sync" + "time" + + "github.com/apernet/quic-go" "github.com/apernet/hysteria/core/v2/server" ) @@ -25,6 +32,7 @@ func NewTrafficStatsServer(secret string) TrafficStatsServer { StatsMap: make(map[string]*trafficStatsEntry), KickMap: make(map[string]struct{}), OnlineMap: make(map[string]int), + StreamMap: make(map[quic.Stream]*server.StreamStats), Secret: secret, } } @@ -33,6 +41,7 @@ type trafficStatsServerImpl struct { Mutex sync.RWMutex StatsMap map[string]*trafficStatsEntry OnlineMap map[string]int + StreamMap map[quic.Stream]*server.StreamStats KickMap map[string]struct{} Secret string } @@ -78,6 +87,20 @@ func (s *trafficStatsServerImpl) LogOnlineState(id string, online bool) { } } +func (s *trafficStatsServerImpl) TraceStream(stream quic.Stream, stats *server.StreamStats) { + s.Mutex.Lock() + defer s.Mutex.Unlock() + + s.StreamMap[stream] = stats +} + +func (s *trafficStatsServerImpl) UntraceStream(stream quic.Stream) { + s.Mutex.Lock() + defer s.Mutex.Unlock() + + delete(s.StreamMap, stream) +} + func (s *trafficStatsServerImpl) ServeHTTP(w http.ResponseWriter, r *http.Request) { if s.Secret != "" && r.Header.Get("Authorization") != s.Secret { http.Error(w, "unauthorized", http.StatusUnauthorized) @@ -99,6 +122,10 @@ func (s *trafficStatsServerImpl) ServeHTTP(w http.ResponseWriter, r *http.Reques s.getOnline(w, r) return } + if r.Method == http.MethodGet && r.URL.Path == "/dump/streams" { + s.getDumpStreams(w, r) + return + } http.NotFound(w, r) } @@ -137,6 +164,126 @@ func (s *trafficStatsServerImpl) getOnline(w http.ResponseWriter, r *http.Reques _, _ = w.Write(jb) } +type dumpStreamEntry struct { + State string `json:"state"` + + Auth string `json:"auth"` + Connection uint32 `json:"connection"` + Stream uint64 `json:"stream"` + + ReqAddr string `json:"req_addr"` + HookedReqAddr string `json:"hooked_req_addr"` + + Tx uint64 `json:"tx"` + Rx uint64 `json:"rx"` + + InitialAt string `json:"initial_at"` + LastActiveAt string `json:"last_active_at"` + + // for text/plain output + initialTime time.Time + lastActiveTime time.Time +} + +func (e *dumpStreamEntry) fromStreamStats(stream quic.Stream, s *server.StreamStats) { + e.State = s.State.Load().String() + e.Auth = s.AuthID + e.Connection = s.ConnID + e.Stream = uint64(stream.StreamID()) + e.ReqAddr = s.ReqAddr.Load() + e.HookedReqAddr = s.HookedReqAddr.Load() + e.Tx = s.Tx.Load() + e.Rx = s.Rx.Load() + e.initialTime = s.InitialTime + e.lastActiveTime = s.LastActiveTime.Load() + e.InitialAt = e.initialTime.Format(time.RFC3339Nano) + e.LastActiveAt = e.lastActiveTime.Format(time.RFC3339Nano) +} + +func formatDumpStreamLine(state, auth, connection, stream, reqAddr, hookedReqAddr, tx, rx, lifetime, lastActive string) string { + return fmt.Sprintf("%-8s %-12s %12s %8s %12s %12s %12s %12s %-16s %s", state, auth, connection, stream, tx, rx, lifetime, lastActive, reqAddr, hookedReqAddr) +} + +func (e *dumpStreamEntry) String() string { + stateText := strings.ToUpper(e.State) + connectionText := fmt.Sprintf("%08X", e.Connection) + streamText := strconv.FormatUint(e.Stream, 10) + reqAddrText := e.ReqAddr + if reqAddrText == "" { + reqAddrText = "-" + } + hookedReqAddrText := e.HookedReqAddr + if hookedReqAddrText == "" { + hookedReqAddrText = "-" + } + txText := strconv.FormatUint(e.Tx, 10) + rxText := strconv.FormatUint(e.Rx, 10) + lifetime := time.Now().Sub(e.initialTime) + if lifetime < 10*time.Minute { + lifetime = lifetime.Round(time.Millisecond) + } else { + lifetime = lifetime.Round(time.Second) + } + lastActive := time.Now().Sub(e.lastActiveTime) + if lastActive < 10*time.Minute { + lastActive = lastActive.Round(time.Millisecond) + } else { + lastActive = lastActive.Round(time.Second) + } + + return formatDumpStreamLine(stateText, e.Auth, connectionText, streamText, reqAddrText, hookedReqAddrText, txText, rxText, lifetime.String(), lastActive.String()) +} + +func (s *trafficStatsServerImpl) getDumpStreams(w http.ResponseWriter, r *http.Request) { + var entries []dumpStreamEntry + + s.Mutex.RLock() + entries = make([]dumpStreamEntry, len(s.StreamMap)) + index := 0 + for stream, stats := range s.StreamMap { + entries[index].fromStreamStats(stream, stats) + index++ + } + s.Mutex.RUnlock() + + slices.SortFunc(entries, func(lhs, rhs dumpStreamEntry) int { + if ret := cmp.Compare(lhs.Auth, rhs.Auth); ret != 0 { + return ret + } + if ret := cmp.Compare(lhs.Connection, rhs.Connection); ret != 0 { + return ret + } + if ret := cmp.Compare(lhs.Stream, rhs.Stream); ret != 0 { + return ret + } + return 0 + }) + + accept := r.Header.Get("Accept") + + if strings.Contains(accept, "text/plain") { + // Generate netstat-like output for humans + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + + // Print table header + _, _ = fmt.Fprintln(w, formatDumpStreamLine("State", "Auth", "Connection", "Stream", "Req-Addr", "Hooked-Req-Addr", "TX-Bytes", "RX-Bytes", "Lifetime", "Last-Active")) + for _, entry := range entries { + _, _ = fmt.Fprintln(w, entry.String()) + } + return + } + + // Response with json by default + wrapper := struct { + Streams []dumpStreamEntry `json:"streams"` + }{entries} + w.Header().Set("Content-Type", "application/json; charset=utf-8") + err := json.NewEncoder(w).Encode(&wrapper) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + func (s *trafficStatsServerImpl) kick(w http.ResponseWriter, r *http.Request) { var ids []string err := json.NewDecoder(r.Body).Decode(&ids) From 7ac8d87ddae55af435738e4430ce6dcffc2ae140 Mon Sep 17 00:00:00 2001 From: Haruue Date: Fri, 8 Nov 2024 16:03:48 +0900 Subject: [PATCH 2/3] test: fix integration_tests for trafficlogger --- .../mocks/mock_TrafficLogger.go | 74 ++++++++++++++++++- .../integration_tests/trafficlogger_test.go | 2 + 2 files changed, 75 insertions(+), 1 deletion(-) diff --git a/core/internal/integration_tests/mocks/mock_TrafficLogger.go b/core/internal/integration_tests/mocks/mock_TrafficLogger.go index 9de44b976e..1ed977efd0 100644 --- a/core/internal/integration_tests/mocks/mock_TrafficLogger.go +++ b/core/internal/integration_tests/mocks/mock_TrafficLogger.go @@ -2,7 +2,12 @@ package mocks -import mock "github.com/stretchr/testify/mock" +import ( + quic "github.com/apernet/quic-go" + mock "github.com/stretchr/testify/mock" + + server "github.com/apernet/hysteria/core/v2/server" +) // MockTrafficLogger is an autogenerated mock type for the TrafficLogger type type MockTrafficLogger struct { @@ -99,6 +104,73 @@ func (_c *MockTrafficLogger_LogTraffic_Call) RunAndReturn(run func(string, uint6 return _c } +// TraceStream provides a mock function with given fields: stream, stats +func (_m *MockTrafficLogger) TraceStream(stream quic.Stream, stats *server.StreamStats) { + _m.Called(stream, stats) +} + +// MockTrafficLogger_TraceStream_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'TraceStream' +type MockTrafficLogger_TraceStream_Call struct { + *mock.Call +} + +// TraceStream is a helper method to define mock.On call +// - stream quic.Stream +// - stats *server.StreamStats +func (_e *MockTrafficLogger_Expecter) TraceStream(stream interface{}, stats interface{}) *MockTrafficLogger_TraceStream_Call { + return &MockTrafficLogger_TraceStream_Call{Call: _e.mock.On("TraceStream", stream, stats)} +} + +func (_c *MockTrafficLogger_TraceStream_Call) Run(run func(stream quic.Stream, stats *server.StreamStats)) *MockTrafficLogger_TraceStream_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(quic.Stream), args[1].(*server.StreamStats)) + }) + return _c +} + +func (_c *MockTrafficLogger_TraceStream_Call) Return() *MockTrafficLogger_TraceStream_Call { + _c.Call.Return() + return _c +} + +func (_c *MockTrafficLogger_TraceStream_Call) RunAndReturn(run func(quic.Stream, *server.StreamStats)) *MockTrafficLogger_TraceStream_Call { + _c.Call.Return(run) + return _c +} + +// UntraceStream provides a mock function with given fields: stream +func (_m *MockTrafficLogger) UntraceStream(stream quic.Stream) { + _m.Called(stream) +} + +// MockTrafficLogger_UntraceStream_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UntraceStream' +type MockTrafficLogger_UntraceStream_Call struct { + *mock.Call +} + +// UntraceStream is a helper method to define mock.On call +// - stream quic.Stream +func (_e *MockTrafficLogger_Expecter) UntraceStream(stream interface{}) *MockTrafficLogger_UntraceStream_Call { + return &MockTrafficLogger_UntraceStream_Call{Call: _e.mock.On("UntraceStream", stream)} +} + +func (_c *MockTrafficLogger_UntraceStream_Call) Run(run func(stream quic.Stream)) *MockTrafficLogger_UntraceStream_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(quic.Stream)) + }) + return _c +} + +func (_c *MockTrafficLogger_UntraceStream_Call) Return() *MockTrafficLogger_UntraceStream_Call { + _c.Call.Return() + return _c +} + +func (_c *MockTrafficLogger_UntraceStream_Call) RunAndReturn(run func(quic.Stream)) *MockTrafficLogger_UntraceStream_Call { + _c.Call.Return(run) + return _c +} + // NewMockTrafficLogger creates a new instance of MockTrafficLogger. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewMockTrafficLogger(t interface { diff --git a/core/internal/integration_tests/trafficlogger_test.go b/core/internal/integration_tests/trafficlogger_test.go index b5355ff694..841f4ffa8f 100644 --- a/core/internal/integration_tests/trafficlogger_test.go +++ b/core/internal/integration_tests/trafficlogger_test.go @@ -62,6 +62,7 @@ func TestClientServerTrafficLoggerTCP(t *testing.T) { return nil }) serverOb.EXPECT().TCP(addr).Return(sobConn, nil).Once() + trafficLogger.EXPECT().TraceStream(mock.Anything, mock.Anything).Return().Once() conn, err := c.TCP(addr) assert.NoError(t, err) @@ -84,6 +85,7 @@ func TestClientServerTrafficLoggerTCP(t *testing.T) { time.Sleep(1 * time.Second) // Need some time for the server to receive the data // Client reads from server again but blocked + trafficLogger.EXPECT().UntraceStream(mock.Anything).Return().Once() trafficLogger.EXPECT().LogTraffic("nobody", uint64(0), uint64(4)).Return(false).Once() trafficLogger.EXPECT().LogOnlineState("nobody", false).Return().Once() sobConnCh <- []byte("nope") From 3e8c20518db0e97ad67b638e85cbe643b26d777a Mon Sep 17 00:00:00 2001 From: Toby Date: Fri, 8 Nov 2024 14:29:50 -0800 Subject: [PATCH 3/3] chore: minor code tweaks --- core/server/config.go | 3 +-- core/server/copy.go | 4 ++-- core/server/server.go | 2 +- extras/trafficlogger/http.go | 3 +-- 4 files changed, 5 insertions(+), 7 deletions(-) diff --git a/core/server/config.go b/core/server/config.go index 19aae53bd0..a01f478f4d 100644 --- a/core/server/config.go +++ b/core/server/config.go @@ -7,10 +7,9 @@ import ( "sync/atomic" "time" - "github.com/apernet/hysteria/core/v2/internal/utils" - "github.com/apernet/hysteria/core/v2/errors" "github.com/apernet/hysteria/core/v2/internal/pmtud" + "github.com/apernet/hysteria/core/v2/internal/utils" "github.com/apernet/quic-go" ) diff --git a/core/server/copy.go b/core/server/copy.go index 2f99ae42cf..7123fc89ea 100644 --- a/core/server/copy.go +++ b/core/server/copy.go @@ -32,7 +32,7 @@ func copyBufferLog(dst io.Writer, src io.Reader, log func(n uint64) bool) error } } -func copyTwoWayWithLogger(id string, serverRw, remoteRw io.ReadWriter, l TrafficLogger, stats *StreamStats) error { +func copyTwoWayEx(id string, serverRw, remoteRw io.ReadWriter, l TrafficLogger, stats *StreamStats) error { errChan := make(chan error, 2) go func() { errChan <- copyBufferLog(serverRw, remoteRw, func(n uint64) bool { @@ -52,7 +52,7 @@ func copyTwoWayWithLogger(id string, serverRw, remoteRw io.ReadWriter, l Traffic return <-errChan } -// copyTwoWay is the "fast-path" version of copyTwoWayWithLogger that does not log traffic. +// copyTwoWay is the "fast-path" version of copyTwoWayEx that does not log traffic or update stream stats. // It uses the built-in io.Copy instead of our own copyBufferLog. func copyTwoWay(serverRw, remoteRw io.ReadWriter) error { errChan := make(chan error, 2) diff --git a/core/server/server.go b/core/server/server.go index f7ad957f81..696f1d0956 100644 --- a/core/server/server.go +++ b/core/server/server.go @@ -280,7 +280,7 @@ func (h *h3sHandler) handleTCPRequest(stream quic.Stream) { } // Start proxying if trafficLogger != nil { - err = copyTwoWayWithLogger(h.authID, stream, tConn, trafficLogger, streamStats) + err = copyTwoWayEx(h.authID, stream, tConn, trafficLogger, streamStats) } else { // Use the fast path if no traffic logger is set err = copyTwoWay(stream, tConn) diff --git a/extras/trafficlogger/http.go b/extras/trafficlogger/http.go index 87eb919f14..d8e6ebd4f1 100644 --- a/extras/trafficlogger/http.go +++ b/extras/trafficlogger/http.go @@ -11,9 +11,8 @@ import ( "sync" "time" - "github.com/apernet/quic-go" - "github.com/apernet/hysteria/core/v2/server" + "github.com/apernet/quic-go" ) const (