diff --git a/session/mocks/session.go b/session/mocks/session.go index a5b06b81..8863d48f 100644 --- a/session/mocks/session.go +++ b/session/mocks/session.go @@ -732,6 +732,20 @@ func (mr *MockSessionPoolMockRecorder) CloseAll() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseAll", reflect.TypeOf((*MockSessionPool)(nil).CloseAll)) } +// GetNumberOfConnectedClients mocks base method. +func (m *MockSessionPool) GetNumberOfConnectedClients() int64 { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetNumberOfConnectedClients") + ret0, _ := ret[0].(int64) + return ret0 +} + +// GetNumberOfConnectedClients indicates an expected call of GetNumberOfConnectedClients. +func (mr *MockSessionPoolMockRecorder) GetNumberOfConnectedClients() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetNumberOfConnectedClients", reflect.TypeOf((*MockSessionPool)(nil).GetNumberOfConnectedClients)) +} + // GetSessionByID mocks base method. func (m *MockSessionPool) GetSessionByID(arg0 int64) session.Session { m.ctrl.T.Helper() diff --git a/session/session.go b/session/session.go index 5d95b141..ac0874f0 100644 --- a/session/session.go +++ b/session/session.go @@ -64,6 +64,7 @@ type SessionPool interface { OnSessionClose(f func(s Session)) CloseAll() AddHandshakeValidator(name string, f func(data *HandshakeData) error) + GetNumberOfConnectedClients() int64 } // HandshakeClientData represents information about the client sent on the handshake. @@ -310,6 +311,11 @@ func (pool *sessionPoolImpl) AddHandshakeValidator(name string, f func(data *Han pool.handshakeValidators[name] = f } +// GetNumberOfConnectedClients returns the number of connected clients +func (pool *sessionPoolImpl) GetNumberOfConnectedClients() int64 { + return pool.GetSessionCount() +} + func (s *sessionImpl) updateEncodedData() error { var b []byte b, err := json.Marshal(s.data) diff --git a/session/session_test.go b/session/session_test.go index 6d288317..ab437859 100644 --- a/session/session_test.go +++ b/session/session_test.go @@ -1507,3 +1507,18 @@ func TestSessionValidateHandshake(t *testing.T) { }) } } + +func TestGetNumberOfConnectedClients(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + entity := mocks.NewMockNetworkEntity(ctrl) + sessionPool := NewSessionPool() + connections := sessionPool.GetNumberOfConnectedClients() + assert.Equal(t, int64(0), connections) + + ss := sessionPool.NewSession(entity, true) + assert.NotNil(t, ss) + connections = sessionPool.GetNumberOfConnectedClients() + assert.Equal(t, int64(1), connections) +} diff --git a/session/static.go b/session/static.go index 4a0b19e7..4dfb463e 100644 --- a/session/static.go +++ b/session/static.go @@ -36,3 +36,8 @@ func OnSessionClose(f func(s Session)) { func CloseAll() { DefaultSessionPool.CloseAll() } + +// GetNumberOfConnectedClients returns the number of connected clients +func GetNumberOfConnectedClients() int64 { + return DefaultSessionPool.GetSessionCount() +}