diff --git a/ntrack/listener.go b/ntrack/listener.go new file mode 100644 index 0000000..5bcfa82 --- /dev/null +++ b/ntrack/listener.go @@ -0,0 +1,92 @@ +package ntrack + +import ( + "context" + "fmt" + "net" + "sync/atomic" + + "github.com/pkg/errors" + "go.opencensus.io/stats" + "go.opencensus.io/stats/view" + "go.opencensus.io/tag" +) + +type trackingListener struct { + net.Listener + stats *Stats +} + +func NewInstrumentedListener(lis net.Listener) (net.Listener, *Stats) { + listenerStats := &Stats{} + listenerStats.init() + + return &trackingListener{ + Listener: lis, + stats: listenerStats, + }, listenerStats +} + +func (tl *trackingListener) Accept() (net.Conn, error) { + conn, err := tl.Listener.Accept() + stats.RecordWithTags(context.TODO(), []tag.Mutator{tag.Upsert(tl.stats.TagSuccess, fmt.Sprintf("%v", err == nil))}, tl.stats.ListenerAccepted.M(1)) + if err != nil { + return nil, errors.Wrap(err, "accept from base listener") + } + + open := atomic.AddInt64(&tl.stats.openConnections, 1) + stats.Record(context.TODO(), tl.stats.OpenConnections.M(open)) + return &serverConn{Conn: conn, stats: tl.stats}, nil +} + +type serverConn struct { + net.Conn + stats *Stats +} + +func (sc *serverConn) Close() error { + err := sc.Conn.Close() + open := atomic.AddInt64(&sc.stats.openConnections, -1) + stats.Record(context.TODO(), + sc.stats.OpenConnections.M(open), + sc.stats.LifetimeClosedConnections.M(1), + ) + return errors.Wrap(err, "close server conn") +} + +type Stats struct { + ListenerAccepted *stats.Int64Measure + LifetimeClosedConnections *stats.Int64Measure + OpenConnections *stats.Int64Measure + openConnections int64 + + TagSuccess tag.Key + + ListenerAcceptedView *view.View + LifetimeClosedConnectionsView *view.View + OpenConnectionsView *view.View +} + +func (s *Stats) init() { + s.ListenerAccepted = stats.Int64("ntrack/listener/accepts", "The number of Accept calls on the net.Listener", stats.UnitDimensionless) + s.LifetimeClosedConnections = stats.Int64("ntrack/listener/closed", "The number of Close calls on the net.Listener", stats.UnitDimensionless) + s.OpenConnections = stats.Int64("ntrack/listener/open", "The number of Open connections from the net.Listener", stats.UnitDimensionless) + + s.TagSuccess, _ = tag.NewKey("success") + + tags := []tag.Key{s.TagSuccess} + + s.ListenerAcceptedView = viewFromStat(s.ListenerAccepted, tags, view.Count()) + s.OpenConnectionsView = viewFromStat(s.OpenConnections, nil, view.LastValue()) + s.LifetimeClosedConnectionsView = viewFromStat(s.LifetimeClosedConnections, nil, view.Count()) +} + +func viewFromStat(ss *stats.Int64Measure, tags []tag.Key, agg *view.Aggregation) *view.View { + return &view.View{ + Name: ss.Name(), + Measure: ss, + Description: ss.Description(), + TagKeys: tags, + Aggregation: agg, + } +} diff --git a/ntrack/listener_test.go b/ntrack/listener_test.go new file mode 100644 index 0000000..eee1f46 --- /dev/null +++ b/ntrack/listener_test.go @@ -0,0 +1,113 @@ +package ntrack + +import ( + "fmt" + "net" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.opencensus.io/stats/view" +) + +func TestListener(t *testing.T) { + var tests = []struct { + viewName string + disableKeepalive bool + expectedValue int64 + }{ + { + viewName: "ntrack/listener/accepts", + disableKeepalive: true, + expectedValue: 5, + }, + { + viewName: "ntrack/listener/accepts", + disableKeepalive: false, + expectedValue: 1, + }, + { + viewName: "ntrack/listener/closed", + disableKeepalive: true, + expectedValue: 5, + }, + { + viewName: "ntrack/listener/open", + disableKeepalive: true, + expectedValue: 0, + }, + { + viewName: "ntrack/listener/open", + disableKeepalive: false, + expectedValue: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.viewName, func(t *testing.T) { + lis, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + ilis, stats := NewInstrumentedListener(lis) + registerViewByName(t, tt.viewName, stats, false) + + testClientConnections(t, ilis, tt.disableKeepalive) + + rows, err := view.RetrieveData(tt.viewName) + require.NoError(t, err) + + switch data := rows[0].Data.(type) { + case *view.CountData: + assert.Equal(t, tt.expectedValue, data.Value) + case *view.LastValueData: + assert.Equal(t, float64(tt.expectedValue), data.Value) + } + registerViewByName(t, tt.viewName, stats, true) + }) + } +} + +func registerViewByName(t *testing.T, name string, stats *Stats, unregister bool) { + var v *view.View + switch name { + case "ntrack/listener/accepts": + v = stats.ListenerAcceptedView + case "ntrack/listener/open": + v = stats.OpenConnectionsView + case "ntrack/listener/closed": + v = stats.LifetimeClosedConnectionsView + } + if unregister { + view.Unregister(v) + } else { + view.Register(v) + } +} + +func testClientConnections(t *testing.T, lis net.Listener, disableKeepalive bool) { + t.Helper() + + srv := &http.Server{ + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }), + } + + go func() { + if err := srv.Serve(lis); err != nil { + t.Fatal(err) + } + }() + + tr := &http.Transport{DisableKeepAlives: disableKeepalive} + client := &http.Client{Transport: tr} + + requestCount := 5 + for i := 0; i < requestCount; i++ { + resp, err := client.Get(fmt.Sprintf("http://%s", lis.Addr())) + require.NoError(t, err) + resp.Body.Close() + } + +}