From 77bf07d34987b827b24ba817811c3a15ab62a025 Mon Sep 17 00:00:00 2001 From: Mike Date: Sun, 2 Jun 2024 10:55:42 -0400 Subject: [PATCH] decouple auth service from bos service --- .mockery.yaml | 8 +- cmd/server/main.go | 45 +++-- config/settings.env | 2 +- foodgroup/auth.go | 55 ++---- foodgroup/auth_test.go | 236 +++++++++-------------- foodgroup/chat_nav.go | 4 +- foodgroup/chat_nav_test.go | 4 +- foodgroup/mock_cookie_issuer_test.go | 90 +++++++++ foodgroup/mock_session_manager_test.go | 21 +- foodgroup/oservice.go | 41 ++-- foodgroup/oservice_test.go | 92 ++++++--- foodgroup/test_helpers.go | 23 +-- foodgroup/types.go | 6 +- server/oscar/alert.go | 111 ----------- server/oscar/alert_test.go | 109 ----------- server/oscar/auth.go | 13 +- server/oscar/auth_test.go | 2 +- server/oscar/bos.go | 21 +- server/oscar/bos_test.go | 15 +- server/oscar/chat.go | 8 +- server/oscar/chat_nav.go | 106 ---------- server/oscar/chat_nav_test.go | 108 ----------- server/oscar/chat_test.go | 9 +- server/oscar/connection_test.go | 4 +- server/oscar/mock_auth_test.go | 102 +++++----- server/oscar/mock_cookie_cracker_test.go | 90 +++++++++ server/oscar/types.go | 5 + state/cookie.go | 111 +++++++++++ state/session.go | 15 -- state/session_manager.go | 9 +- state/session_manager_test.go | 71 ++++--- state/session_test.go | 9 - 32 files changed, 701 insertions(+), 844 deletions(-) create mode 100644 foodgroup/mock_cookie_issuer_test.go delete mode 100644 server/oscar/alert.go delete mode 100644 server/oscar/alert_test.go delete mode 100644 server/oscar/chat_nav.go delete mode 100644 server/oscar/chat_nav_test.go create mode 100644 server/oscar/mock_cookie_cracker_test.go create mode 100644 server/oscar/types.go create mode 100644 state/cookie.go diff --git a/.mockery.yaml b/.mockery.yaml index 3820f1e2..c237bb18 100644 --- a/.mockery.yaml +++ b/.mockery.yaml @@ -17,6 +17,9 @@ packages: OnlineNotifier: config: filename: "mock_online_notifier_test.go" + CookieCracker: + config: + filename: "mock_cookie_cracker_test.go" github.com/mk6i/retro-aim-server/server/http: interfaces: UserManager: @@ -97,4 +100,7 @@ packages: filename: "mock_bart_manager_test.go" LegacyBuddyListManager: config: - filename: "mock_legacy_buddy_list_manager_test.go" \ No newline at end of file + filename: "mock_legacy_buddy_list_manager_test.go" + CookieIssuer: + config: + filename: "mock_cookie_issuer_test.go" diff --git a/cmd/server/main.go b/cmd/server/main.go index 892913d4..c19174d7 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -3,6 +3,7 @@ package main import ( "fmt" "log/slog" + "net" "os" "sync" @@ -31,6 +32,12 @@ func main() { os.Exit(1) } + cookieBaker, err := state.NewHMACCookieBaker() + if err != nil { + _, _ = fmt.Fprintf(os.Stderr, "unable to create HMAC cookie baker: %s\n", err.Error()) + os.Exit(1) + } + logger := middleware.NewLogger(cfg) sessionManager := state.NewInMemorySessionManager(logger) chatRegistry := state.NewChatRegistry() @@ -45,10 +52,10 @@ func main() { }() go func(logger *slog.Logger) { logger = logger.With("svc", "BOS") - authService := foodgroup.NewAuthService(cfg, sessionManager, sessionManager, feedbagStore, feedbagStore, chatRegistry, adjListBuddyListStore) + authService := foodgroup.NewAuthService(cfg, sessionManager, sessionManager, feedbagStore, feedbagStore, chatRegistry, adjListBuddyListStore, cookieBaker) bartService := foodgroup.NewBARTService(logger, feedbagStore, sessionManager, feedbagStore, adjListBuddyListStore) buddyService := foodgroup.NewBuddyService(sessionManager, feedbagStore, adjListBuddyListStore) - oServiceService := foodgroup.NewOServiceService(cfg, sessionManager, feedbagStore, adjListBuddyListStore, logger) + oServiceService := foodgroup.NewOServiceService(cfg, sessionManager, feedbagStore, adjListBuddyListStore, logger, cookieBaker) oServiceServiceForBOS := foodgroup.NewOServiceServiceForBOS(*oServiceService, chatRegistry) locateService := foodgroup.NewLocateService(sessionManager, feedbagStore, feedbagStore, adjListBuddyListStore) newChatSessMgr := func() foodgroup.SessionManager { return state.NewInMemorySessionManager(logger) } @@ -71,15 +78,17 @@ func main() { OServiceBOSHandler: handler.NewOServiceHandlerForBOS(logger, oServiceService, oServiceServiceForBOS), PermitDenyHandler: handler.NewPermitDenyHandler(logger, foodgroupService), }), + CookieCracker: cookieBaker, Logger: logger, OnlineNotifier: oServiceServiceForBOS, + ListenAddr: net.JoinHostPort("", cfg.BOSPort), }.Start() wg.Done() }(logger) go func(logger *slog.Logger) { logger = logger.With("svc", "CHAT") - authService := foodgroup.NewAuthService(cfg, sessionManager, sessionManager, feedbagStore, feedbagStore, chatRegistry, adjListBuddyListStore) - oServiceService := foodgroup.NewOServiceService(cfg, sessionManager, feedbagStore, adjListBuddyListStore, logger) + authService := foodgroup.NewAuthService(cfg, sessionManager, sessionManager, feedbagStore, feedbagStore, chatRegistry, adjListBuddyListStore, cookieBaker) + oServiceService := foodgroup.NewOServiceService(cfg, sessionManager, feedbagStore, adjListBuddyListStore, logger, cookieBaker) chatService := foodgroup.NewChatService(chatRegistry) oServiceServiceForChat := foodgroup.NewOServiceServiceForChat(*oServiceService, chatRegistry) @@ -92,18 +101,20 @@ func main() { }), Logger: logger, OnlineNotifier: oServiceServiceForChat, + CookieCracker: cookieBaker, }.Start() wg.Done() }(logger) go func(logger *slog.Logger) { logger = logger.With("svc", "CHAT_NAV") - authService := foodgroup.NewAuthService(cfg, sessionManager, sessionManager, feedbagStore, feedbagStore, chatRegistry, adjListBuddyListStore) - oServiceService := foodgroup.NewOServiceService(cfg, sessionManager, feedbagStore, adjListBuddyListStore, logger) + sessionManager := state.NewInMemorySessionManager(logger) + authService := foodgroup.NewAuthService(cfg, sessionManager, sessionManager, feedbagStore, feedbagStore, chatRegistry, adjListBuddyListStore, cookieBaker) + oServiceService := foodgroup.NewOServiceService(cfg, sessionManager, feedbagStore, adjListBuddyListStore, logger, cookieBaker) oServiceServiceForChatNav := foodgroup.NewOServiceServiceForChatNav(*oServiceService, chatRegistry) newChatSessMgr := func() foodgroup.SessionManager { return state.NewInMemorySessionManager(logger) } chatNavService := foodgroup.NewChatNavService(logger, chatRegistry, state.NewChatRoom, newChatSessMgr) - oscar.ChatNavServer{ + oscar.BOSServer{ AuthService: authService, Config: cfg, Handler: handler.NewChatNavRouter(handler.Handlers{ @@ -112,35 +123,41 @@ func main() { }), Logger: logger, OnlineNotifier: oServiceServiceForChatNav, + ListenAddr: net.JoinHostPort("", cfg.ChatNavPort), + CookieCracker: cookieBaker, }.Start() wg.Done() }(logger) go func(logger *slog.Logger) { logger = logger.With("svc", "ALERT") - authService := foodgroup.NewAuthService(cfg, sessionManager, sessionManager, feedbagStore, feedbagStore, chatRegistry, adjListBuddyListStore) - oServiceService := foodgroup.NewOServiceService(cfg, sessionManager, feedbagStore, adjListBuddyListStore, logger) + sessionManager := state.NewInMemorySessionManager(logger) + authService := foodgroup.NewAuthService(cfg, sessionManager, sessionManager, feedbagStore, feedbagStore, chatRegistry, adjListBuddyListStore, cookieBaker) + oServiceService := foodgroup.NewOServiceService(cfg, sessionManager, feedbagStore, adjListBuddyListStore, logger, cookieBaker) oServiceServiceForAlert := foodgroup.NewOServiceServiceForAlert(*oServiceService) - oscar.AlertServer{ + oscar.BOSServer{ AuthService: authService, Config: cfg, Handler: handler.NewAlertRouter(handler.Handlers{ AlertHandler: handler.NewAlertHandler(logger), OServiceAlertHandler: handler.NewOServiceHandlerForAlert(logger, oServiceService, oServiceServiceForAlert), }), + CookieCracker: cookieBaker, Logger: logger, OnlineNotifier: oServiceServiceForAlert, + ListenAddr: net.JoinHostPort("", cfg.AlertPort), }.Start() wg.Done() }(logger) go func(logger *slog.Logger) { logger = logger.With("svc", "AUTH") - authHandler := foodgroup.NewAuthService(cfg, sessionManager, nil, feedbagStore, feedbagStore, chatRegistry, adjListBuddyListStore) + authHandler := foodgroup.NewAuthService(cfg, sessionManager, nil, feedbagStore, feedbagStore, chatRegistry, adjListBuddyListStore, cookieBaker) oscar.AuthServer{ - AuthService: authHandler, - Config: cfg, - Logger: logger, + AuthService: authHandler, + Config: cfg, + Logger: logger, + CookieCracker: cookieBaker, }.Start() wg.Done() }(logger) diff --git a/config/settings.env b/config/settings.env index ecd5b84a..bfa6f480 100644 --- a/config/settings.env +++ b/config/settings.env @@ -34,7 +34,7 @@ export FAIL_FAST=false # Set logging granularity. Possible values: 'trace', 'debug', 'info', 'warn', # 'error'. -export LOG_LEVEL=info +export LOG_LEVEL=debug # The hostname that AIM clients connect to in order to reach OSCAR services # (auth, BOS, BUCP, etc). Make sure the hostname is reachable by all clients. diff --git a/foodgroup/auth.go b/foodgroup/auth.go index fa7f5d9f..30f96fc0 100644 --- a/foodgroup/auth.go +++ b/foodgroup/auth.go @@ -14,17 +14,16 @@ import ( "github.com/google/uuid" ) -// authCookieLen is the fixed auth cookie length. -const authCookieLen = 256 - // NewAuthService creates a new instance of AuthService. -func NewAuthService(cfg config.Config, +func NewAuthService( + cfg config.Config, sessionManager SessionManager, messageRelayer MessageRelayer, feedbagManager FeedbagManager, userManager UserManager, chatRegistry ChatRegistry, legacyBuddyListManager LegacyBuddyListManager, + cookieIssuer CookieIssuer, ) *AuthService { return &AuthService{ chatRegistry: chatRegistry, @@ -34,6 +33,7 @@ func NewAuthService(cfg config.Config, messageRelayer: messageRelayer, sessionManager: sessionManager, userManager: userManager, + cookieIssuer: cookieIssuer, } } @@ -48,26 +48,27 @@ type AuthService struct { messageRelayer MessageRelayer sessionManager SessionManager userManager UserManager + cookieIssuer CookieIssuer } -// RetrieveChatSession returns a chat room session. Return nil if the session -// does not exist. -func (s AuthService) RetrieveChatSession(loginCookie []byte) (*state.Session, error) { +// RegisterChatSession creates and returns a chat room session. +func (s AuthService) RegisterChatSession(loginCookie []byte) (*state.Session, error) { c := chatLoginCookie{} if err := wire.Unmarshal(&c, bytes.NewBuffer(loginCookie)); err != nil { return nil, err } - _, chatSessMgr, err := s.chatRegistry.Retrieve(c.Cookie) + room, chatSessMgr, err := s.chatRegistry.Retrieve(c.ChatCookie) if err != nil { return nil, err } - return chatSessMgr.(SessionManager).RetrieveSession(c.SessID), nil + chatSess := chatSessMgr.(SessionManager).AddSession(c.ScreenName) + chatSess.SetChatRoomCookie(room.Cookie) + return chatSess, nil } -// RetrieveBOSSession returns a user's session. Return nil if the session does -// not exist. -func (s AuthService) RetrieveBOSSession(sessionID string) (*state.Session, error) { - return s.sessionManager.RetrieveSession(sessionID), nil +// RegisterBOSSession creates and returns a user's session. +func (s AuthService) RegisterBOSSession(sessionID string) (*state.Session, error) { + return s.sessionManager.AddSession(sessionID), nil } // Signout removes this user's session and notifies users who have this user on @@ -167,11 +168,10 @@ func (s AuthService) BUCPChallenge( // (wire.LoginTLVTagsErrorSubcode). func (s AuthService) BUCPLogin( bodyIn wire.SNAC_0x17_0x02_BUCPLoginRequest, - newUUIDFn func() uuid.UUID, newUserFn func(screenName string) (state.User, error), ) (wire.SNACMessage, error) { - block, err := s.login(bodyIn.TLVList, newUserFn, newUUIDFn) + block, err := s.login(bodyIn.TLVList, newUserFn) if err != nil { return wire.SNACMessage{}, err } @@ -197,12 +197,8 @@ func (s AuthService) BUCPLogin( // (wire.LoginTLVTagsReconnectHere) and an authorization cookie // (wire.LoginTLVTagsAuthorizationCookie). Else, an error code is set // (wire.LoginTLVTagsErrorSubcode). -func (s AuthService) FLAPLogin( - frame wire.FLAPSignonFrame, - newUUIDFn func() uuid.UUID, - newUserFn func(screenName string) (state.User, error), -) (wire.TLVRestBlock, error) { - return s.login(frame.TLVList, newUserFn, newUUIDFn) +func (s AuthService) FLAPLogin(frame wire.FLAPSignonFrame, newUserFn func(screenName string) (state.User, error)) (wire.TLVRestBlock, error) { + return s.login(frame.TLVList, newUserFn) } // login validates a user's credentials and creates their session. it returns @@ -210,7 +206,6 @@ func (s AuthService) FLAPLogin( func (s AuthService) login( TLVList wire.TLVList, newUserFn func(screenName string) (state.User, error), - newUUIDFn func() uuid.UUID, ) (wire.TLVRestBlock, error) { screenName, found := TLVList.String(wire.LoginTLVTagsScreenName) @@ -249,24 +244,16 @@ func (s AuthService) login( } } - sess := s.sessionManager.AddSession(newUUIDFn().String(), screenName) - - // Some clients (such as perl NET::OSCAR) expect the auth cookie to be - // exactly 256 bytes, even though the cookie is stored in a - // variable-length TLV. Pad the auth cookie to make sure it's exactly - // 256 bytes. - if len(sess.ID()) > authCookieLen { - return wire.TLVRestBlock{}, fmt.Errorf("sess is too long, expect 256 bytes, got %d", len(sess.ID())) + cookie, err := s.cookieIssuer.Issue([]byte(screenName)) + if err != nil { + return wire.TLVRestBlock{}, fmt.Errorf("failed to make auth cookie: %w", err) } - authCookie := make([]byte, authCookieLen) - copy(authCookie, sess.ID()) - // auth success return wire.TLVRestBlock{ TLVList: []wire.TLV{ wire.NewTLV(wire.LoginTLVTagsScreenName, screenName), wire.NewTLV(wire.LoginTLVTagsReconnectHere, net.JoinHostPort(s.config.OSCARHost, s.config.BOSPort)), - wire.NewTLV(wire.LoginTLVTagsAuthorizationCookie, authCookie), + wire.NewTLV(wire.LoginTLVTagsAuthorizationCookie, cookie), }, }, nil } diff --git a/foodgroup/auth_test.go b/foodgroup/auth_test.go index b25957a2..f620d56f 100644 --- a/foodgroup/auth_test.go +++ b/foodgroup/auth_test.go @@ -15,13 +15,11 @@ import ( ) func TestAuthService_BUCPLoginRequest(t *testing.T) { - sessUUID := uuid.UUID{1, 2, 3} user := state.User{ ScreenName: "screen_name", AuthKey: "auth_key", } assert.NoError(t, user.HashPassword("the_password")) - userSession := newTestSession(user.ScreenName, sessOptID(sessUUID.String())) cases := []struct { // name is the unit test name @@ -63,13 +61,10 @@ func TestAuthService_BUCPLoginRequest(t *testing.T) { }, }, }, - sessionManagerParams: sessionManagerParams{ - addSessionParams: addSessionParams{ - { - sessID: userSession.ID(), - screenName: user.ScreenName, - result: userSession, - }, + cookieIssuerParams: cookieIssuerParams{ + { + data: []byte(user.ScreenName), + cookie: []byte("the-cookie"), }, }, }, @@ -83,8 +78,7 @@ func TestAuthService_BUCPLoginRequest(t *testing.T) { TLVList: wire.TLVList{ wire.NewTLV(wire.LoginTLVTagsScreenName, user.ScreenName), wire.NewTLV(wire.LoginTLVTagsReconnectHere, "127.0.0.1:1234"), - wire.NewTLV(wire.LoginTLVTagsAuthorizationCookie, append([]byte(userSession.ID()), - make([]byte, authCookieLen-len([]byte(userSession.ID())))...)), + wire.NewTLV(wire.LoginTLVTagsAuthorizationCookie, []byte("the-cookie")), }, }, }, @@ -119,13 +113,10 @@ func TestAuthService_BUCPLoginRequest(t *testing.T) { }, }, }, - sessionManagerParams: sessionManagerParams{ - addSessionParams: addSessionParams{ - { - sessID: userSession.ID(), - screenName: user.ScreenName, - result: userSession, - }, + cookieIssuerParams: cookieIssuerParams{ + { + data: []byte(user.ScreenName), + cookie: []byte("the-cookie"), }, }, }, @@ -142,8 +133,7 @@ func TestAuthService_BUCPLoginRequest(t *testing.T) { TLVList: wire.TLVList{ wire.NewTLV(wire.LoginTLVTagsScreenName, user.ScreenName), wire.NewTLV(wire.LoginTLVTagsReconnectHere, "127.0.0.1:1234"), - wire.NewTLV(wire.LoginTLVTagsAuthorizationCookie, append([]byte(userSession.ID()), - make([]byte, authCookieLen-len([]byte(userSession.ID())))...)), + wire.NewTLV(wire.LoginTLVTagsAuthorizationCookie, []byte("the-cookie")), }, }, }, @@ -178,13 +168,10 @@ func TestAuthService_BUCPLoginRequest(t *testing.T) { }, }, }, - sessionManagerParams: sessionManagerParams{ - addSessionParams: addSessionParams{ - { - sessID: userSession.ID(), - screenName: user.ScreenName, - result: userSession, - }, + cookieIssuerParams: cookieIssuerParams{ + { + data: []byte(user.ScreenName), + cookie: []byte("the-cookie"), }, }, }, @@ -201,8 +188,7 @@ func TestAuthService_BUCPLoginRequest(t *testing.T) { TLVList: wire.TLVList{ wire.NewTLV(wire.LoginTLVTagsScreenName, user.ScreenName), wire.NewTLV(wire.LoginTLVTagsReconnectHere, "127.0.0.1:1234"), - wire.NewTLV(wire.LoginTLVTagsAuthorizationCookie, append([]byte(userSession.ID()), - make([]byte, authCookieLen-len([]byte(userSession.ID())))...)), + wire.NewTLV(wire.LoginTLVTagsAuthorizationCookie, []byte("the-cookie")), }, }, }, @@ -238,13 +224,10 @@ func TestAuthService_BUCPLoginRequest(t *testing.T) { }, }, }, - sessionManagerParams: sessionManagerParams{ - addSessionParams: addSessionParams{ - { - sessID: userSession.ID(), - screenName: user.ScreenName, - result: userSession, - }, + cookieIssuerParams: cookieIssuerParams{ + { + data: []byte(user.ScreenName), + cookie: []byte("the-cookie"), }, }, }, @@ -261,8 +244,7 @@ func TestAuthService_BUCPLoginRequest(t *testing.T) { TLVList: wire.TLVList{ wire.NewTLV(wire.LoginTLVTagsScreenName, user.ScreenName), wire.NewTLV(wire.LoginTLVTagsReconnectHere, "127.0.0.1:1234"), - wire.NewTLV(wire.LoginTLVTagsAuthorizationCookie, append([]byte(userSession.ID()), - make([]byte, authCookieLen-len([]byte(userSession.ID())))...)), + wire.NewTLV(wire.LoginTLVTagsAuthorizationCookie, []byte("the-cookie")), }, }, }, @@ -407,20 +389,20 @@ func TestAuthService_BUCPLoginRequest(t *testing.T) { Return(params.err) } sessionManager := newMockSessionManager(t) - for _, params := range tc.mockParams.addSessionParams { - sessionManager.EXPECT(). - AddSession(params.sessID, params.screenName). - Return(params.result) + cookieIssuer := newMockCookieIssuer(t) + for _, params := range tc.mockParams.cookieIssuerParams { + cookieIssuer.EXPECT(). + Issue(params.data). + Return(params.cookie, params.err) } + svc := AuthService{ config: tc.cfg, + cookieIssuer: cookieIssuer, sessionManager: sessionManager, userManager: userManager, } - fnNewUUID := func() uuid.UUID { - return sessUUID - } - outputSNAC, err := svc.BUCPLogin(tc.inputSNAC, fnNewUUID, tc.newUserFn) + outputSNAC, err := svc.BUCPLogin(tc.inputSNAC, tc.newUserFn) assert.ErrorIs(t, err, tc.wantErr) assert.Equal(t, tc.expectOutput, outputSNAC) }) @@ -428,13 +410,11 @@ func TestAuthService_BUCPLoginRequest(t *testing.T) { } func TestAuthService_FLAPLoginResponse(t *testing.T) { - sessUUID := uuid.UUID{1, 2, 3} user := state.User{ ScreenName: "screen_name", AuthKey: "auth_key", } assert.NoError(t, user.HashPassword("the_password")) - userSession := newTestSession(user.ScreenName, sessOptID(sessUUID.String())) // obfuscated password value: "the_password" roastedPassword := []byte{0x87, 0x4E, 0xE4, 0x9B, 0x49, 0xE7, 0xA8, 0xE1, 0x06, 0xCC, 0xCB, 0x82} @@ -479,13 +459,10 @@ func TestAuthService_FLAPLoginResponse(t *testing.T) { }, }, }, - sessionManagerParams: sessionManagerParams{ - addSessionParams: addSessionParams{ - { - sessID: userSession.ID(), - screenName: user.ScreenName, - result: userSession, - }, + cookieIssuerParams: cookieIssuerParams{ + { + data: []byte(user.ScreenName), + cookie: []byte("the-cookie"), }, }, }, @@ -493,8 +470,7 @@ func TestAuthService_FLAPLoginResponse(t *testing.T) { TLVList: wire.TLVList{ wire.NewTLV(wire.LoginTLVTagsScreenName, user.ScreenName), wire.NewTLV(wire.LoginTLVTagsReconnectHere, "127.0.0.1:1234"), - wire.NewTLV(wire.LoginTLVTagsAuthorizationCookie, append([]byte(userSession.ID()), - make([]byte, authCookieLen-len([]byte(userSession.ID())))...)), + wire.NewTLV(wire.LoginTLVTagsAuthorizationCookie, []byte("the-cookie")), }, }, }, @@ -527,13 +503,10 @@ func TestAuthService_FLAPLoginResponse(t *testing.T) { }, }, }, - sessionManagerParams: sessionManagerParams{ - addSessionParams: addSessionParams{ - { - sessID: userSession.ID(), - screenName: user.ScreenName, - result: userSession, - }, + cookieIssuerParams: cookieIssuerParams{ + { + data: []byte(user.ScreenName), + cookie: []byte("the-cookie"), }, }, }, @@ -544,8 +517,7 @@ func TestAuthService_FLAPLoginResponse(t *testing.T) { TLVList: wire.TLVList{ wire.NewTLV(wire.LoginTLVTagsScreenName, user.ScreenName), wire.NewTLV(wire.LoginTLVTagsReconnectHere, "127.0.0.1:1234"), - wire.NewTLV(wire.LoginTLVTagsAuthorizationCookie, append([]byte(userSession.ID()), - make([]byte, authCookieLen-len([]byte(userSession.ID())))...)), + wire.NewTLV(wire.LoginTLVTagsAuthorizationCookie, []byte("the-cookie")), }, }, }, @@ -578,13 +550,10 @@ func TestAuthService_FLAPLoginResponse(t *testing.T) { }, }, }, - sessionManagerParams: sessionManagerParams{ - addSessionParams: addSessionParams{ - { - sessID: userSession.ID(), - screenName: user.ScreenName, - result: userSession, - }, + cookieIssuerParams: cookieIssuerParams{ + { + data: []byte(user.ScreenName), + cookie: []byte("the-cookie"), }, }, }, @@ -595,8 +564,7 @@ func TestAuthService_FLAPLoginResponse(t *testing.T) { TLVList: wire.TLVList{ wire.NewTLV(wire.LoginTLVTagsScreenName, user.ScreenName), wire.NewTLV(wire.LoginTLVTagsReconnectHere, "127.0.0.1:1234"), - wire.NewTLV(wire.LoginTLVTagsAuthorizationCookie, append([]byte(userSession.ID()), - make([]byte, authCookieLen-len([]byte(userSession.ID())))...)), + wire.NewTLV(wire.LoginTLVTagsAuthorizationCookie, []byte("the-cookie")), }, }, }, @@ -630,13 +598,10 @@ func TestAuthService_FLAPLoginResponse(t *testing.T) { }, }, }, - sessionManagerParams: sessionManagerParams{ - addSessionParams: addSessionParams{ - { - sessID: userSession.ID(), - screenName: user.ScreenName, - result: userSession, - }, + cookieIssuerParams: cookieIssuerParams{ + { + data: []byte(user.ScreenName), + cookie: []byte("the-cookie"), }, }, }, @@ -647,8 +612,7 @@ func TestAuthService_FLAPLoginResponse(t *testing.T) { TLVList: wire.TLVList{ wire.NewTLV(wire.LoginTLVTagsScreenName, user.ScreenName), wire.NewTLV(wire.LoginTLVTagsReconnectHere, "127.0.0.1:1234"), - wire.NewTLV(wire.LoginTLVTagsAuthorizationCookie, append([]byte(userSession.ID()), - make([]byte, authCookieLen-len([]byte(userSession.ID())))...)), + wire.NewTLV(wire.LoginTLVTagsAuthorizationCookie, []byte("the-cookie")), }, }, }, @@ -783,20 +747,19 @@ func TestAuthService_FLAPLoginResponse(t *testing.T) { Return(params.err) } sessionManager := newMockSessionManager(t) - for _, params := range tc.mockParams.addSessionParams { - sessionManager.EXPECT(). - AddSession(params.sessID, params.screenName). - Return(params.result) + cookieIssuer := newMockCookieIssuer(t) + for _, params := range tc.mockParams.cookieIssuerParams { + cookieIssuer.EXPECT(). + Issue(params.data). + Return(params.cookie, params.err) } svc := AuthService{ config: tc.cfg, + cookieIssuer: cookieIssuer, sessionManager: sessionManager, userManager: userManager, } - fnNewUUID := func() uuid.UUID { - return sessUUID - } - outputSNAC, err := svc.FLAPLogin(tc.inputSNAC, fnNewUUID, tc.newUserFn) + outputSNAC, err := svc.FLAPLogin(tc.inputSNAC, tc.newUserFn) assert.ErrorIs(t, err, tc.wantErr) assert.Equal(t, tc.expectOutput, outputSNAC) }) @@ -972,20 +935,20 @@ func TestAuthService_BUCPChallengeRequest(t *testing.T) { } } -func TestAuthService_RetrieveChatSession_HappyPath(t *testing.T) { +func TestAuthService_RegisterChatSession_HappyPath(t *testing.T) { cookie := "chat-1234" - sess := newTestSession("screen-name", sessOptCannedID) + sess := newTestSession("screen-name") c := chatLoginCookie{ - Cookie: cookie, - SessID: sess.ID(), + ChatCookie: cookie, + ScreenName: sess.ScreenName(), } buf := &bytes.Buffer{} assert.NoError(t, wire.Marshal(c, buf)) sessionManager := newMockSessionManager(t) sessionManager.EXPECT(). - RetrieveSession(sess.ID()). + AddSession(sess.ScreenName()). Return(sess) chatRegistry := newMockChatRegistry(t) @@ -993,20 +956,22 @@ func TestAuthService_RetrieveChatSession_HappyPath(t *testing.T) { Retrieve(cookie). Return(state.ChatRoom{}, sessionManager, nil) - svc := NewAuthService(config.Config{}, nil, nil, nil, nil, chatRegistry, nil) + cookieIssuer := newMockCookieIssuer(t) + + svc := NewAuthService(config.Config{}, nil, nil, nil, nil, chatRegistry, nil, cookieIssuer) - have, err := svc.RetrieveChatSession(buf.Bytes()) + have, err := svc.RegisterChatSession(buf.Bytes()) assert.NoError(t, err) assert.Equal(t, sess, have) } -func TestAuthService_RetrieveChatSession_ChatNotFound(t *testing.T) { +func TestAuthService_RegisterBOSSession_ChatNotFound(t *testing.T) { cookie := "chat-1234" - sess := newTestSession("screen-name", sessOptCannedID) + sess := newTestSession("screen-name") c := chatLoginCookie{ - Cookie: cookie, - SessID: sess.ID(), + ChatCookie: cookie, + ScreenName: sess.ScreenName(), } loginCookie := &bytes.Buffer{} assert.NoError(t, wire.Marshal(c, loginCookie)) @@ -1016,66 +981,43 @@ func TestAuthService_RetrieveChatSession_ChatNotFound(t *testing.T) { Retrieve(cookie). Return(state.ChatRoom{}, nil, state.ErrChatRoomNotFound) - svc := NewAuthService(config.Config{}, nil, nil, nil, nil, chatRegistry, nil) + cookieIssuer := newMockCookieIssuer(t) + svc := NewAuthService(config.Config{}, nil, nil, nil, nil, chatRegistry, nil, cookieIssuer) - _, err := svc.RetrieveChatSession(loginCookie.Bytes()) + _, err := svc.RegisterChatSession(loginCookie.Bytes()) assert.ErrorIs(t, err, state.ErrChatRoomNotFound) } -func TestAuthService_RetrieveChatSession_SessionNotFound(t *testing.T) { - cookie := "chat-1234" - sess := newTestSession("screen-name", sessOptCannedID) - - c := chatLoginCookie{ - Cookie: cookie, - SessID: sess.ID(), - } - buf := &bytes.Buffer{} - assert.NoError(t, wire.Marshal(c, buf)) - - sessionManager := newMockSessionManager(t) - sessionManager.EXPECT(). - RetrieveSession(sess.ID()). - Return(nil) - - chatRegistry := newMockChatRegistry(t) - chatRegistry.EXPECT(). - Retrieve(cookie). - Return(state.ChatRoom{}, sessionManager, nil) - - svc := NewAuthService(config.Config{}, nil, nil, nil, nil, chatRegistry, nil) - - have, err := svc.RetrieveChatSession(buf.Bytes()) - assert.NoError(t, err) - assert.Nil(t, have) -} - -func TestAuthService_RetrieveBOSSession_HappyPath(t *testing.T) { - sess := newTestSession("screen-name", sessOptCannedID) +func TestAuthService_RegisterBOSSession_HappyPath(t *testing.T) { + sess := newTestSession("screen-name") sessionManager := newMockSessionManager(t) sessionManager.EXPECT(). - RetrieveSession(sess.ID()). + AddSession(sess.ScreenName()). Return(sess) - svc := NewAuthService(config.Config{}, sessionManager, nil, nil, nil, nil, nil) + cookieIssuer := newMockCookieIssuer(t) - have, err := svc.RetrieveBOSSession(sess.ID()) + svc := NewAuthService(config.Config{}, sessionManager, nil, nil, nil, nil, nil, cookieIssuer) + + have, err := svc.RegisterBOSSession(sess.ScreenName()) assert.NoError(t, err) assert.Equal(t, sess, have) } -func TestAuthService_RetrieveBOSSession_SessionNotFound(t *testing.T) { - sess := newTestSession("screen-name", sessOptCannedID) +func TestAuthService_RegisterBOSSession_SessionNotFound(t *testing.T) { + sess := newTestSession("screen-name") sessionManager := newMockSessionManager(t) sessionManager.EXPECT(). - RetrieveSession(sess.ID()). + AddSession(sess.ScreenName()). Return(nil) - svc := NewAuthService(config.Config{}, sessionManager, nil, nil, nil, nil, nil) + cookieIssuer := newMockCookieIssuer(t) + + svc := NewAuthService(config.Config{}, sessionManager, nil, nil, nil, nil, nil, cookieIssuer) - have, err := svc.RetrieveBOSSession(sess.ID()) + have, err := svc.RegisterBOSSession(sess.ScreenName()) assert.NoError(t, err) assert.Nil(t, have) } @@ -1213,7 +1155,9 @@ func TestAuthService_SignoutChat(t *testing.T) { Retrieve(tt.chatRoom.Cookie). Return(tt.chatRoom, chatSessionManager, tt.wantErr) - svc := NewAuthService(config.Config{}, nil, nil, nil, nil, chatRegistry, nil) + cookieIssuer := newMockCookieIssuer(t) + + svc := NewAuthService(config.Config{}, nil, nil, nil, nil, chatRegistry, nil, cookieIssuer) err := svc.SignoutChat(nil, tt.userSession) assert.ErrorIs(t, err, tt.wantErr) @@ -1339,7 +1283,9 @@ func TestAuthService_Signout(t *testing.T) { Return(params.result) } - svc := NewAuthService(config.Config{}, sessionManager, messageRelayer, feedbagManager, nil, nil, legacyBuddyListManager) + cookieIssuer := newMockCookieIssuer(t) + + svc := NewAuthService(config.Config{}, sessionManager, messageRelayer, feedbagManager, nil, nil, legacyBuddyListManager, cookieIssuer) err := svc.Signout(nil, tt.userSession) assert.ErrorIs(t, err, tt.wantErr) diff --git a/foodgroup/chat_nav.go b/foodgroup/chat_nav.go index 479ef546..6ef0523f 100644 --- a/foodgroup/chat_nav.go +++ b/foodgroup/chat_nav.go @@ -87,7 +87,7 @@ func (s ChatNavService) CreateRoom(_ context.Context, sess *state.Session, inFra s.chatRegistry.Register(room, chatSessMgr) // add user to chat room - chatSess := chatSessMgr.AddSession(sess.ID(), sess.ScreenName()) + chatSess := chatSessMgr.AddSession(sess.ScreenName()) chatSess.SetChatRoomCookie(room.Cookie) return wire.SNACMessage{ @@ -115,7 +115,7 @@ func (s ChatNavService) CreateRoom(_ context.Context, sess *state.Session, inFra } // RequestRoomInfo returns wire.ChatNavNavInfo, which contains metadata for -// the chat room specified in the inFrame.Cookie. +// the chat room specified in the inFrame.hmacCookie. func (s ChatNavService) RequestRoomInfo(_ context.Context, inFrame wire.SNACFrame, inBody wire.SNAC_0x0D_0x04_ChatNavRequestRoomInfo) (wire.SNACMessage, error) { room, _, err := s.chatRegistry.Retrieve(inBody.Cookie) if err != nil { diff --git a/foodgroup/chat_nav_test.go b/foodgroup/chat_nav_test.go index 7b2d2c60..1618ab7b 100644 --- a/foodgroup/chat_nav_test.go +++ b/foodgroup/chat_nav_test.go @@ -12,13 +12,13 @@ import ( ) func TestChatNavService_CreateRoom(t *testing.T) { - bosSess := newTestSession("user-screen-name", sessOptCannedID) + bosSess := newTestSession("user-screen-name") chatSess := &state.Session{} chatRegistry := state.NewChatRegistry() sessionManager := newMockSessionManager(t) - sessionManager.EXPECT().AddSession(bosSess.ID(), bosSess.ScreenName()). + sessionManager.EXPECT().AddSession(bosSess.ScreenName()). Return(chatSess) newChatRoom := func() state.ChatRoom { diff --git a/foodgroup/mock_cookie_issuer_test.go b/foodgroup/mock_cookie_issuer_test.go new file mode 100644 index 00000000..a02e6490 --- /dev/null +++ b/foodgroup/mock_cookie_issuer_test.go @@ -0,0 +1,90 @@ +// Code generated by mockery v2.40.1. DO NOT EDIT. + +package foodgroup + +import mock "github.com/stretchr/testify/mock" + +// mockCookieIssuer is an autogenerated mock type for the CookieIssuer type +type mockCookieIssuer struct { + mock.Mock +} + +type mockCookieIssuer_Expecter struct { + mock *mock.Mock +} + +func (_m *mockCookieIssuer) EXPECT() *mockCookieIssuer_Expecter { + return &mockCookieIssuer_Expecter{mock: &_m.Mock} +} + +// Issue provides a mock function with given fields: data +func (_m *mockCookieIssuer) Issue(data []byte) ([]byte, error) { + ret := _m.Called(data) + + if len(ret) == 0 { + panic("no return value specified for Issue") + } + + var r0 []byte + var r1 error + if rf, ok := ret.Get(0).(func([]byte) ([]byte, error)); ok { + return rf(data) + } + if rf, ok := ret.Get(0).(func([]byte) []byte); ok { + r0 = rf(data) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]byte) + } + } + + if rf, ok := ret.Get(1).(func([]byte) error); ok { + r1 = rf(data) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// mockCookieIssuer_Issue_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Issue' +type mockCookieIssuer_Issue_Call struct { + *mock.Call +} + +// Issue is a helper method to define mock.On call +// - data []byte +func (_e *mockCookieIssuer_Expecter) Issue(data interface{}) *mockCookieIssuer_Issue_Call { + return &mockCookieIssuer_Issue_Call{Call: _e.mock.On("Issue", data)} +} + +func (_c *mockCookieIssuer_Issue_Call) Run(run func(data []byte)) *mockCookieIssuer_Issue_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].([]byte)) + }) + return _c +} + +func (_c *mockCookieIssuer_Issue_Call) Return(_a0 []byte, _a1 error) *mockCookieIssuer_Issue_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *mockCookieIssuer_Issue_Call) RunAndReturn(run func([]byte) ([]byte, error)) *mockCookieIssuer_Issue_Call { + _c.Call.Return(run) + return _c +} + +// newMockCookieIssuer creates a new instance of mockCookieIssuer. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func newMockCookieIssuer(t interface { + mock.TestingT + Cleanup(func()) +}) *mockCookieIssuer { + mock := &mockCookieIssuer{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/foodgroup/mock_session_manager_test.go b/foodgroup/mock_session_manager_test.go index b4e21b44..ba7f325f 100644 --- a/foodgroup/mock_session_manager_test.go +++ b/foodgroup/mock_session_manager_test.go @@ -20,17 +20,17 @@ func (_m *mockSessionManager) EXPECT() *mockSessionManager_Expecter { return &mockSessionManager_Expecter{mock: &_m.Mock} } -// AddSession provides a mock function with given fields: sessID, screenName -func (_m *mockSessionManager) AddSession(sessID string, screenName string) *state.Session { - ret := _m.Called(sessID, screenName) +// AddSession provides a mock function with given fields: screenName +func (_m *mockSessionManager) AddSession(screenName string) *state.Session { + ret := _m.Called(screenName) if len(ret) == 0 { panic("no return value specified for AddSession") } var r0 *state.Session - if rf, ok := ret.Get(0).(func(string, string) *state.Session); ok { - r0 = rf(sessID, screenName) + if rf, ok := ret.Get(0).(func(string) *state.Session); ok { + r0 = rf(screenName) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*state.Session) @@ -46,15 +46,14 @@ type mockSessionManager_AddSession_Call struct { } // AddSession is a helper method to define mock.On call -// - sessID string // - screenName string -func (_e *mockSessionManager_Expecter) AddSession(sessID interface{}, screenName interface{}) *mockSessionManager_AddSession_Call { - return &mockSessionManager_AddSession_Call{Call: _e.mock.On("AddSession", sessID, screenName)} +func (_e *mockSessionManager_Expecter) AddSession(screenName interface{}) *mockSessionManager_AddSession_Call { + return &mockSessionManager_AddSession_Call{Call: _e.mock.On("AddSession", screenName)} } -func (_c *mockSessionManager_AddSession_Call) Run(run func(sessID string, screenName string)) *mockSessionManager_AddSession_Call { +func (_c *mockSessionManager_AddSession_Call) Run(run func(screenName string)) *mockSessionManager_AddSession_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(string), args[1].(string)) + run(args[0].(string)) }) return _c } @@ -64,7 +63,7 @@ func (_c *mockSessionManager_AddSession_Call) Return(_a0 *state.Session) *mockSe return _c } -func (_c *mockSessionManager_AddSession_Call) RunAndReturn(run func(string, string) *state.Session) *mockSessionManager_AddSession_Call { +func (_c *mockSessionManager_AddSession_Call) RunAndReturn(run func(string) *state.Session) *mockSessionManager_AddSession_Call { _c.Call.Return(run) return _c } diff --git a/foodgroup/oservice.go b/foodgroup/oservice.go index 5f66c57b..9dbcb473 100644 --- a/foodgroup/oservice.go +++ b/foodgroup/oservice.go @@ -21,6 +21,7 @@ func NewOServiceService( feedbagManager FeedbagManager, legacyBuddyListManager LegacyBuddyListManager, logger *slog.Logger, + cookieIssuer CookieIssuer, ) *OServiceService { return &OServiceService{ cfg: cfg, @@ -28,6 +29,7 @@ func NewOServiceService( legacyBuddyListManager: legacyBuddyListManager, messageRelayer: messageRelayer, logger: logger, + cookieIssuer: cookieIssuer, } } @@ -39,6 +41,7 @@ type OServiceService struct { legacyBuddyListManager LegacyBuddyListManager messageRelayer MessageRelayer logger *slog.Logger + cookieIssuer CookieIssuer } // ClientVersions informs the server what food group versions the client @@ -505,8 +508,8 @@ type OServiceServiceForBOS struct { // chatLoginCookie represents credentials used to authenticate a user chat // session. type chatLoginCookie struct { - Cookie string `len_prefix:"uint8"` - SessID string `len_prefix:"uint16"` + ChatCookie string `len_prefix:"uint8"` + ScreenName string `len_prefix:"uint8"` } // ServiceRequest handles service discovery, providing a host name and metadata @@ -521,6 +524,10 @@ type chatLoginCookie struct { func (s OServiceServiceForBOS) ServiceRequest(_ context.Context, sess *state.Session, inFrame wire.SNACFrame, inBody wire.SNAC_0x01_0x04_OServiceServiceRequest) (wire.SNACMessage, error) { switch inBody.FoodGroup { case wire.Alert: + cookie, err := s.cookieIssuer.Issue([]byte(sess.ScreenName())) + if err != nil { + return wire.SNACMessage{}, err + } return wire.SNACMessage{ Frame: wire.SNACFrame{ FoodGroup: wire.OService, @@ -531,7 +538,7 @@ func (s OServiceServiceForBOS) ServiceRequest(_ context.Context, sess *state.Ses TLVRestBlock: wire.TLVRestBlock{ TLVList: wire.TLVList{ wire.NewTLV(wire.OServiceTLVTagsReconnectHere, net.JoinHostPort(s.cfg.OSCARHost, s.cfg.AlertPort)), - wire.NewTLV(wire.OServiceTLVTagsLoginCookie, sess.ID()), + wire.NewTLV(wire.OServiceTLVTagsLoginCookie, cookie), wire.NewTLV(wire.OServiceTLVTagsGroupID, wire.Alert), wire.NewTLV(wire.OServiceTLVTagsSSLCertName, ""), wire.NewTLV(wire.OServiceTLVTagsSSLState, uint8(0x00)), @@ -540,6 +547,10 @@ func (s OServiceServiceForBOS) ServiceRequest(_ context.Context, sess *state.Ses }, }, nil case wire.ChatNav: + cookie, err := s.cookieIssuer.Issue([]byte(sess.ScreenName())) + if err != nil { + return wire.SNACMessage{}, err + } return wire.SNACMessage{ Frame: wire.SNACFrame{ FoodGroup: wire.OService, @@ -550,7 +561,7 @@ func (s OServiceServiceForBOS) ServiceRequest(_ context.Context, sess *state.Ses TLVRestBlock: wire.TLVRestBlock{ TLVList: wire.TLVList{ wire.NewTLV(wire.OServiceTLVTagsReconnectHere, net.JoinHostPort(s.cfg.OSCARHost, s.cfg.ChatNavPort)), - wire.NewTLV(wire.OServiceTLVTagsLoginCookie, sess.ID()), + wire.NewTLV(wire.OServiceTLVTagsLoginCookie, cookie), wire.NewTLV(wire.OServiceTLVTagsGroupID, wire.ChatNav), wire.NewTLV(wire.OServiceTLVTagsSSLCertName, ""), wire.NewTLV(wire.OServiceTLVTagsSSLState, uint8(0x00)), @@ -569,12 +580,23 @@ func (s OServiceServiceForBOS) ServiceRequest(_ context.Context, sess *state.Ses return wire.SNACMessage{}, err } - room, chatSessMgr, err := s.chatRegistry.Retrieve(roomSNAC.Cookie) + room, _, err := s.chatRegistry.Retrieve(roomSNAC.Cookie) + if err != nil { + return wire.SNACMessage{}, fmt.Errorf("unable to retrieve room info: %w", err) + } + + loginCookie := chatLoginCookie{ + ChatCookie: room.Cookie, + ScreenName: sess.ScreenName(), + } + buf := &bytes.Buffer{} + if err := wire.Marshal(loginCookie, buf); err != nil { + return wire.SNACMessage{}, err + } + cookie, err := s.cookieIssuer.Issue(buf.Bytes()) if err != nil { return wire.SNACMessage{}, err } - chatSess := chatSessMgr.(SessionManager).AddSession(sess.ID(), sess.ScreenName()) - chatSess.SetChatRoomCookie(room.Cookie) return wire.SNACMessage{ Frame: wire.SNACFrame{ @@ -586,10 +608,7 @@ func (s OServiceServiceForBOS) ServiceRequest(_ context.Context, sess *state.Ses TLVRestBlock: wire.TLVRestBlock{ TLVList: wire.TLVList{ wire.NewTLV(wire.OServiceTLVTagsReconnectHere, net.JoinHostPort(s.cfg.OSCARHost, s.cfg.ChatPort)), - wire.NewTLV(wire.OServiceTLVTagsLoginCookie, chatLoginCookie{ - Cookie: room.Cookie, - SessID: sess.ID(), - }), + wire.NewTLV(wire.OServiceTLVTagsLoginCookie, cookie), wire.NewTLV(wire.OServiceTLVTagsGroupID, wire.Chat), wire.NewTLV(wire.OServiceTLVTagsSSLCertName, ""), wire.NewTLV(wire.OServiceTLVTagsSSLState, uint8(0x00)), diff --git a/foodgroup/oservice_test.go b/foodgroup/oservice_test.go index d79ef5db..41b5d044 100644 --- a/foodgroup/oservice_test.go +++ b/foodgroup/oservice_test.go @@ -32,6 +32,9 @@ func TestOServiceServiceForBOS_ServiceRequest(t *testing.T) { // expectSNACFrame is the SNAC frame sent from the server to the recipient // client expectOutput wire.SNACMessage + // mockParams is the list of params sent to mocks that satisfy this + // method's dependencies + mockParams mockParams // expectErr is the expected error returned by the router expectErr error }{ @@ -54,7 +57,7 @@ func TestOServiceServiceForBOS_ServiceRequest(t *testing.T) { OSCARHost: "127.0.0.1", ChatNavPort: "1234", }, - userSession: newTestSession("user_screen_name", sessOptCannedID), + userSession: newTestSession("user_screen_name"), inputSNAC: wire.SNACMessage{ Frame: wire.SNACFrame{ RequestID: 1234, @@ -73,7 +76,7 @@ func TestOServiceServiceForBOS_ServiceRequest(t *testing.T) { TLVRestBlock: wire.TLVRestBlock{ TLVList: wire.TLVList{ wire.NewTLV(wire.OServiceTLVTagsReconnectHere, "127.0.0.1:1234"), - wire.NewTLV(wire.OServiceTLVTagsLoginCookie, newTestSession("user_screen_name", sessOptCannedID).ID()), + wire.NewTLV(wire.OServiceTLVTagsLoginCookie, []byte("the-cookie")), wire.NewTLV(wire.OServiceTLVTagsGroupID, wire.ChatNav), wire.NewTLV(wire.OServiceTLVTagsSSLCertName, ""), wire.NewTLV(wire.OServiceTLVTagsSSLState, uint8(0x00)), @@ -81,6 +84,14 @@ func TestOServiceServiceForBOS_ServiceRequest(t *testing.T) { }, }, }, + mockParams: mockParams{ + cookieIssuerParams: cookieIssuerParams{ + { + data: []byte("user_screen_name"), + cookie: []byte("the-cookie"), + }, + }, + }, }, { name: "request info for connecting to alert svc, return alert svc connection metadata", @@ -88,7 +99,7 @@ func TestOServiceServiceForBOS_ServiceRequest(t *testing.T) { OSCARHost: "127.0.0.1", AlertPort: "1234", }, - userSession: newTestSession("user_screen_name", sessOptCannedID), + userSession: newTestSession("user_screen_name"), inputSNAC: wire.SNACMessage{ Frame: wire.SNACFrame{ RequestID: 1234, @@ -107,7 +118,7 @@ func TestOServiceServiceForBOS_ServiceRequest(t *testing.T) { TLVRestBlock: wire.TLVRestBlock{ TLVList: wire.TLVList{ wire.NewTLV(wire.OServiceTLVTagsReconnectHere, "127.0.0.1:1234"), - wire.NewTLV(wire.OServiceTLVTagsLoginCookie, newTestSession("user_screen_name", sessOptCannedID).ID()), + wire.NewTLV(wire.OServiceTLVTagsLoginCookie, []byte("the-cookie")), wire.NewTLV(wire.OServiceTLVTagsGroupID, wire.Alert), wire.NewTLV(wire.OServiceTLVTagsSSLCertName, ""), wire.NewTLV(wire.OServiceTLVTagsSSLState, uint8(0x00)), @@ -115,6 +126,14 @@ func TestOServiceServiceForBOS_ServiceRequest(t *testing.T) { }, }, }, + mockParams: mockParams{ + cookieIssuerParams: cookieIssuerParams{ + { + data: []byte("user_screen_name"), + cookie: []byte("the-cookie"), + }, + }, + }, }, { name: "request info for connecting to chat room, return chat service and chat room metadata", @@ -130,7 +149,7 @@ func TestOServiceServiceForBOS_ServiceRequest(t *testing.T) { InstanceNumber: 16, Name: "my new chat", }, - userSession: newTestSession("user_screen_name", sessOptCannedID), + userSession: newTestSession("user_screen_name"), inputSNAC: wire.SNACMessage{ Frame: wire.SNACFrame{ RequestID: 1234, @@ -158,10 +177,7 @@ func TestOServiceServiceForBOS_ServiceRequest(t *testing.T) { TLVRestBlock: wire.TLVRestBlock{ TLVList: wire.TLVList{ wire.NewTLV(wire.OServiceTLVTagsReconnectHere, "127.0.0.1:1234"), - wire.NewTLV(wire.OServiceTLVTagsLoginCookie, chatLoginCookie{ - Cookie: "the-chat-cookie", - SessID: "user-session-id", - }), + wire.NewTLV(wire.OServiceTLVTagsLoginCookie, []byte("the-cookie")), wire.NewTLV(wire.OServiceTLVTagsGroupID, wire.Chat), wire.NewTLV(wire.OServiceTLVTagsSSLCertName, ""), wire.NewTLV(wire.OServiceTLVTagsSSLState, uint8(0x00)), @@ -169,6 +185,17 @@ func TestOServiceServiceForBOS_ServiceRequest(t *testing.T) { }, }, }, + mockParams: mockParams{ + cookieIssuerParams: cookieIssuerParams{ + { + data: []byte{ + 0x0F, 't', 'h', 'e', '-', 'c', 'h', 'a', 't', '-', 'c', 'o', 'o', 'k', 'i', 'e', + 0x10, 'u', 's', 'e', 'r', '_', 's', 'c', 'r', 'e', 'e', 'n', '_', 'n', 'a', 'm', 'e', + }, + cookie: []byte("the-cookie"), + }, + }, + }, }, { name: "request info for connecting to non-existent chat room, return ErrChatRoomNotFound", @@ -177,7 +204,7 @@ func TestOServiceServiceForBOS_ServiceRequest(t *testing.T) { ChatPort: "1234", }, chatRoom: nil, - userSession: newTestSession("user_screen_name", sessOptCannedID), + userSession: newTestSession("user_screen_name"), inputSNAC: wire.SNACMessage{ Frame: wire.SNACFrame{ RequestID: 1234, @@ -209,16 +236,23 @@ func TestOServiceServiceForBOS_ServiceRequest(t *testing.T) { chatSess := &state.Session{} if tc.chatRoom != nil { sessionManager.EXPECT(). - AddSession(tc.userSession.ID(), tc.userSession.ScreenName()). + AddSession(tc.userSession.ScreenName()). Return(chatSess). Maybe() chatRegistry.Register(*tc.chatRoom, sessionManager) } + cookieIssuer := newMockCookieIssuer(t) + for _, params := range tc.mockParams.cookieIssuerParams { + cookieIssuer.EXPECT(). + Issue(params.data). + Return(params.cookie, params.err) + } // // send input SNAC // svc := NewOServiceServiceForBOS(OServiceService{ - cfg: tc.cfg, + cfg: tc.cfg, + cookieIssuer: cookieIssuer, }, chatRegistry) outputSNAC, err := svc.ServiceRequest(nil, tc.userSession, tc.inputSNAC.Frame, @@ -227,10 +261,6 @@ func TestOServiceServiceForBOS_ServiceRequest(t *testing.T) { if tc.expectErr != nil { return } - if tc.chatRoom != nil { - // assert the user session is linked to the chat room - assert.Equal(t, chatSess.ChatRoomCookie(), tc.chatRoom.Cookie) - } // // verify output // @@ -405,10 +435,11 @@ func TestSetUserInfoFields(t *testing.T) { WhoAddedUser(params.userScreenName). Return(params.result) } + cookieIssuer := newMockCookieIssuer(t) // // send input SNAC // - svc := NewOServiceService(config.Config{}, messageRelayer, feedbagManager, legacyBuddyListManager, slog.Default()) + svc := NewOServiceService(config.Config{}, messageRelayer, feedbagManager, legacyBuddyListManager, slog.Default(), cookieIssuer) outputSNAC, err := svc.SetUserInfoFields(nil, tc.userSession, tc.inputSNAC.Frame, tc.inputSNAC.Body.(wire.SNAC_0x01_0x1E_OServiceSetUserInfoFields)) assert.ErrorIs(t, err, tc.expectErr) @@ -424,7 +455,8 @@ func TestSetUserInfoFields(t *testing.T) { } func TestOServiceService_RateParamsQuery(t *testing.T) { - svc := NewOServiceService(config.Config{}, nil, nil, nil, slog.Default()) + cookieIssuer := newMockCookieIssuer(t) + svc := NewOServiceService(config.Config{}, nil, nil, nil, slog.Default(), cookieIssuer) have := svc.RateParamsQuery(nil, wire.SNACFrame{RequestID: 1234}) want := wire.SNACMessage{ @@ -1349,7 +1381,8 @@ func TestOServiceService_RateParamsQuery(t *testing.T) { } func TestOServiceServiceForBOS_OServiceHostOnline(t *testing.T) { - svc := NewOServiceServiceForBOS(*NewOServiceService(config.Config{}, nil, nil, nil, slog.Default()), nil) + cookieIssuer := newMockCookieIssuer(t) + svc := NewOServiceServiceForBOS(*NewOServiceService(config.Config{}, nil, nil, nil, slog.Default(), cookieIssuer), nil) want := wire.SNACMessage{ Frame: wire.SNACFrame{ @@ -1376,7 +1409,8 @@ func TestOServiceServiceForBOS_OServiceHostOnline(t *testing.T) { } func TestOServiceServiceForChat_OServiceHostOnline(t *testing.T) { - svc := NewOServiceServiceForChat(*NewOServiceService(config.Config{}, nil, nil, nil, slog.Default()), nil) + cookieIssuer := newMockCookieIssuer(t) + svc := NewOServiceServiceForChat(*NewOServiceService(config.Config{}, nil, nil, nil, slog.Default(), cookieIssuer), nil) want := wire.SNACMessage{ Frame: wire.SNACFrame{ @@ -1396,7 +1430,8 @@ func TestOServiceServiceForChat_OServiceHostOnline(t *testing.T) { } func TestOServiceService_ClientVersions(t *testing.T) { - svc := NewOServiceService(config.Config{}, nil, nil, nil, slog.Default()) + cookieIssuer := newMockCookieIssuer(t) + svc := NewOServiceService(config.Config{}, nil, nil, nil, slog.Default(), cookieIssuer) want := wire.SNACMessage{ Frame: wire.SNACFrame{ @@ -1419,7 +1454,8 @@ func TestOServiceService_ClientVersions(t *testing.T) { } func TestOServiceService_UserInfoQuery(t *testing.T) { - svc := NewOServiceService(config.Config{}, nil, nil, nil, slog.Default()) + cookieIssuer := newMockCookieIssuer(t) + svc := NewOServiceService(config.Config{}, nil, nil, nil, slog.Default(), cookieIssuer) sess := newTestSession("test-user") want := wire.SNACMessage{ @@ -1530,7 +1566,8 @@ func TestOServiceService_IdleNotification(t *testing.T) { WhoAddedUser(params.userScreenName). Return(params.result) } - svc := NewOServiceService(config.Config{}, messageRelayer, feedbagManager, legacyBuddyListManager, slog.Default()) + cookieIssuer := newMockCookieIssuer(t) + svc := NewOServiceService(config.Config{}, messageRelayer, feedbagManager, legacyBuddyListManager, slog.Default(), cookieIssuer) haveErr := svc.IdleNotification(nil, tt.sess, tt.bodyIn) assert.ErrorIs(t, tt.wantErr, haveErr) @@ -1900,7 +1937,8 @@ func TestOServiceServiceForChat_ClientOnline(t *testing.T) { } func TestOServiceServiceForChatNav_HostOnline(t *testing.T) { - svc := NewOServiceServiceForChatNav(*NewOServiceService(config.Config{}, nil, nil, nil, slog.Default()), nil) + cookieIssuer := newMockCookieIssuer(t) + svc := NewOServiceServiceForChatNav(*NewOServiceService(config.Config{}, nil, nil, nil, slog.Default(), cookieIssuer), nil) want := wire.SNACMessage{ Frame: wire.SNACFrame{ @@ -1920,7 +1958,8 @@ func TestOServiceServiceForChatNav_HostOnline(t *testing.T) { } func TestOServiceServiceForAlert_HostOnline(t *testing.T) { - svc := NewOServiceServiceForAlert(*NewOServiceService(config.Config{}, nil, nil, nil, slog.Default())) + cookieIssuer := newMockCookieIssuer(t) + svc := NewOServiceServiceForAlert(*NewOServiceService(config.Config{}, nil, nil, nil, slog.Default(), cookieIssuer)) want := wire.SNACMessage{ Frame: wire.SNACFrame{ @@ -1940,7 +1979,8 @@ func TestOServiceServiceForAlert_HostOnline(t *testing.T) { } func TestOServiceService_SetPrivacyFlags(t *testing.T) { - svc := NewOServiceServiceForAlert(*NewOServiceService(config.Config{}, nil, nil, nil, slog.Default())) + cookieIssuer := newMockCookieIssuer(t) + svc := NewOServiceServiceForAlert(*NewOServiceService(config.Config{}, nil, nil, nil, slog.Default(), cookieIssuer)) body := wire.SNAC_0x01_0x14_OServiceSetPrivacyFlags{ PrivacyFlags: wire.OServicePrivacyFlagMember | wire.OServicePrivacyFlagIdle, } diff --git a/foodgroup/test_helpers.go b/foodgroup/test_helpers.go index 0892fe24..a7019c1a 100644 --- a/foodgroup/test_helpers.go +++ b/foodgroup/test_helpers.go @@ -19,6 +19,7 @@ type mockParams struct { profileManagerParams sessionManagerParams userManagerParams + cookieIssuerParams } // bartManagerParams is a helper struct that contains mock parameters for @@ -279,6 +280,14 @@ type whoAddedUserParams []struct { result []string } +// cookieIssuerParams is the list of parameters passed at the mock +// CookieIssuer.Issue call site +type cookieIssuerParams []struct { + data []byte + cookie []byte + err error +} + // sessOptWarning sets a warning level on the session object func sessOptWarning(level uint16) func(session *state.Session) { return func(session *state.Session) { @@ -286,20 +295,6 @@ func sessOptWarning(level uint16) func(session *state.Session) { } } -// sessOptCannedID sets a canned session ID ("user-session-id") on the session -// object -func sessOptCannedID(session *state.Session) { - session.SetID("user-session-id") -} - -// sessOptCannedID sets a canned session ID ("user-session-id") on the session -// object -func sessOptID(ID string) func(session *state.Session) { - return func(session *state.Session) { - session.SetID(ID) - } -} - // sessOptAwayMessage sets away message on the session object func sessOptAwayMessage(awayMessage string) func(session *state.Session) { return func(session *state.Session) { diff --git a/foodgroup/types.go b/foodgroup/types.go index 9902204f..54d8ba0f 100644 --- a/foodgroup/types.go +++ b/foodgroup/types.go @@ -80,7 +80,7 @@ type UserManager interface { type SessionManager interface { Empty() bool - AddSession(sessID string, screenName string) *state.Session + AddSession(screenName string) *state.Session RemoveSession(sess *state.Session) RetrieveSession(ID string) *state.Session } @@ -112,3 +112,7 @@ type BARTManager interface { BARTUpsert(itemHash []byte, payload []byte) error BARTRetrieve(itemHash []byte) ([]byte, error) } + +type CookieIssuer interface { + Issue(data []byte) ([]byte, error) +} diff --git a/server/oscar/alert.go b/server/oscar/alert.go deleted file mode 100644 index 25c2769d..00000000 --- a/server/oscar/alert.go +++ /dev/null @@ -1,111 +0,0 @@ -package oscar - -import ( - "context" - "errors" - "io" - "log/slog" - "net" - "os" - - "github.com/mk6i/retro-aim-server/config" - "github.com/mk6i/retro-aim-server/state" - "github.com/mk6i/retro-aim-server/wire" -) - -// AlertServer provides client connection lifecycle management for the Alert -// service. This server, whose handlers are all no-op, exists solely to satisfy -// AIM 4.x, which throws an error when it can't connect to the alert service. -type AlertServer struct { - AuthService - Handler - Logger *slog.Logger - OnlineNotifier - config.Config -} - -// Start starts a TCP server and listens for connections. The initial -// authentication handshake sequences are handled by this method. The remaining -// requests are relayed to Handler. -func (rt AlertServer) Start() { - addr := net.JoinHostPort("", rt.Config.AlertPort) - listener, err := net.Listen("tcp", addr) - if err != nil { - rt.Logger.Error("unable to bind ALERT server address", "err", err.Error()) - os.Exit(1) - } - defer listener.Close() - - rt.Logger.Info("starting ALERT service", "host", net.JoinHostPort(rt.Config.OSCARHost, rt.Config.AlertPort)) - - for { - conn, err := listener.Accept() - if err != nil { - rt.Logger.Error(err.Error()) - continue - } - ctx := context.Background() - ctx = context.WithValue(ctx, "ip", conn.RemoteAddr().String()) - rt.Logger.DebugContext(ctx, "accepted connection") - go func() { - if err := rt.handleNewConnection(ctx, conn); err != nil { - rt.Logger.Info("user session failed", "err", err.Error()) - } - }() - } -} - -func (rt AlertServer) handleNewConnection(ctx context.Context, rwc io.ReadWriteCloser) error { - flapc := wire.NewFlapClient(100, rwc, rwc) - - if err := flapc.SendSignonFrame(nil); err != nil { - return err - } - flap, err := flapc.ReceiveSignonFrame() - if err != nil { - return err - } - - var ok bool - sessionID, ok := flap.Slice(wire.OServiceTLVTagsLoginCookie) - if !ok { - return errors.New("unable to get session id from payload") - } - - bosSess, err := rt.RetrieveBOSSession(string(sessionID)) - if err != nil { - return err - } - if bosSess == nil { - return errors.New("session not found") - } - - defer func() { - bosSess.Close() - rwc.Close() - if err := rt.Signout(ctx, bosSess); err != nil { - rt.Logger.ErrorContext(ctx, "error notifying departure", "err", err.Error()) - } - }() - - ctx = context.WithValue(ctx, "screenName", bosSess.ScreenName()) - - msg := rt.OnlineNotifier.HostOnline() - if err := flapc.SendSNAC(msg.Frame, msg.Body); err != nil { - return err - } - - // We copy the session object here to make sure that - // dispatchIncomingMessages does not consume relayed messages produced by - // the BOS server. Without this hack, message consumption would be split - // between the BOS server and Alert server, which would result in - // incorrect sequence number generation, because each server has its own - // sequence counter. This hack can be removed by decoupling FLAP routing - // and message relaying, which are both performed in - // dispatchIncomingMessages. - sessCopy := state.NewSession() - sessCopy.SetScreenName(bosSess.ScreenName()) - sessCopy.SetID(bosSess.ID()) - - return dispatchIncomingMessages(ctx, sessCopy, flapc, rwc, rt.Logger, rt.Handler, rt.Config) -} diff --git a/server/oscar/alert_test.go b/server/oscar/alert_test.go deleted file mode 100644 index f53992b9..00000000 --- a/server/oscar/alert_test.go +++ /dev/null @@ -1,109 +0,0 @@ -package oscar - -import ( - "bytes" - "context" - "io" - "log/slog" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - - "github.com/mk6i/retro-aim-server/state" - "github.com/mk6i/retro-aim-server/wire" -) - -func TestAlertServer_handleNewConnection(t *testing.T) { - sess := state.NewSession() - sess.SetID("login-cookie-1234") - - clientReader, serverWriter := io.Pipe() - serverReader, clientWriter := io.Pipe() - - go func() { - // < receive FLAPSignonFrame - flap := wire.FLAPFrame{} - assert.NoError(t, wire.Unmarshal(&flap, serverReader)) - buf, err := flap.ReadBody(serverReader) - assert.NoError(t, err) - flapSignonFrame := wire.FLAPSignonFrame{} - assert.NoError(t, wire.Unmarshal(&flapSignonFrame, buf)) - - // > send FLAPSignonFrame - flapSignonFrame = wire.FLAPSignonFrame{ - FLAPVersion: 1, - } - flapSignonFrame.Append(wire.NewTLV(wire.OServiceTLVTagsLoginCookie, []byte(sess.ID()))) - buf = &bytes.Buffer{} - assert.NoError(t, wire.Marshal(flapSignonFrame, buf)) - flap = wire.FLAPFrame{ - StartMarker: 42, - FrameType: wire.FLAPFrameSignon, - PayloadLength: uint16(buf.Len()), - } - assert.NoError(t, wire.Marshal(flap, serverWriter)) - _, err = serverWriter.Write(buf.Bytes()) - assert.NoError(t, err) - - // < receive SNAC_0x01_0x03_OServiceHostOnline - flap = wire.FLAPFrame{} - assert.NoError(t, wire.Unmarshal(&flap, serverReader)) - buf, err = flap.ReadBody(serverReader) - assert.NoError(t, err) - frame := wire.SNACFrame{} - assert.NoError(t, wire.Unmarshal(&frame, buf)) - body := wire.SNAC_0x01_0x03_OServiceHostOnline{} - assert.NoError(t, wire.Unmarshal(&body, buf)) - - // send the first request that should get relayed to BOSRouter.Handle - flapc := wire.NewFlapClient(0, nil, serverWriter) - frame = wire.SNACFrame{ - FoodGroup: wire.OService, - SubGroup: wire.OServiceClientOnline, - } - assert.NoError(t, flapc.SendSNAC(frame, struct{}{})) - assert.NoError(t, serverWriter.Close()) - }() - - authService := newMockAuthService(t) - authService.EXPECT(). - RetrieveBOSSession(sess.ID()). - Return(sess, nil) - authService.EXPECT(). - Signout(mock.Anything, sess). - Return(nil) - - onlineNotifier := newMockOnlineNotifier(t) - onlineNotifier.EXPECT(). - HostOnline(). - Return(wire.SNACMessage{ - Frame: wire.SNACFrame{ - FoodGroup: wire.OService, - SubGroup: wire.OServiceHostOnline, - }, - Body: wire.SNAC_0x01_0x03_OServiceHostOnline{}, - }) - - router := newMockHandler(t) - router.EXPECT(). - Handle(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). - Run(func(ctx context.Context, sess *state.Session, inFrame wire.SNACFrame, r io.Reader, rw ResponseWriter) { - assert.Equal(t, wire.SNACFrame{ - FoodGroup: wire.OService, - SubGroup: wire.OServiceClientOnline, - }, inFrame) - }).Return(nil) - - rt := AlertServer{ - AuthService: authService, - Handler: router, - Logger: slog.Default(), - OnlineNotifier: onlineNotifier, - } - rwc := pipeRWC{ - PipeReader: clientReader, - PipeWriter: clientWriter, - } - assert.NoError(t, rt.handleNewConnection(context.Background(), rwc)) -} diff --git a/server/oscar/auth.go b/server/oscar/auth.go index e313024e..00d35072 100644 --- a/server/oscar/auth.go +++ b/server/oscar/auth.go @@ -16,10 +16,10 @@ import ( type AuthService interface { BUCPChallenge(bodyIn wire.SNAC_0x17_0x06_BUCPChallengeRequest, newUUID func() uuid.UUID) (wire.SNACMessage, error) - BUCPLogin(bodyIn wire.SNAC_0x17_0x02_BUCPLoginRequest, newUUID func() uuid.UUID, fn func(screenName string) (state.User, error)) (wire.SNACMessage, error) - FLAPLogin(frame wire.FLAPSignonFrame, newUUIDFn func() uuid.UUID, newUserFn func(screenName string) (state.User, error)) (wire.TLVRestBlock, error) - RetrieveBOSSession(sessionID string) (*state.Session, error) - RetrieveChatSession(loginCookie []byte) (*state.Session, error) + BUCPLogin(bodyIn wire.SNAC_0x17_0x02_BUCPLoginRequest, fn func(screenName string) (state.User, error)) (wire.SNACMessage, error) + FLAPLogin(frame wire.FLAPSignonFrame, newUserFn func(screenName string) (state.User, error)) (wire.TLVRestBlock, error) + RegisterBOSSession(sessionID string) (*state.Session, error) + RegisterChatSession(loginCookie []byte) (*state.Session, error) Signout(ctx context.Context, sess *state.Session) error SignoutChat(ctx context.Context, sess *state.Session) error } @@ -30,6 +30,7 @@ type AuthServer struct { AuthService config.Config Logger *slog.Logger + CookieCracker } // Start starts the authentication server and listens for new connections. @@ -79,7 +80,7 @@ func (rt AuthServer) handleNewConnection(rwc io.ReadWriteCloser) error { } func (rt AuthServer) processFLAPAuth(signonFrame wire.FLAPSignonFrame, flapc *wire.FlapClient) error { - tlv, err := rt.AuthService.FLAPLogin(signonFrame, uuid.New, state.NewStubUser) + tlv, err := rt.AuthService.FLAPLogin(signonFrame, state.NewStubUser) if err != nil { return err } @@ -105,7 +106,7 @@ func (rt AuthServer) processBUCPAuth(flapc *wire.FlapClient, err error) error { return err } - outSNAC, err = rt.BUCPLogin(loginRequest, uuid.New, state.NewStubUser) + outSNAC, err = rt.BUCPLogin(loginRequest, state.NewStubUser) if err != nil { return err } diff --git a/server/oscar/auth_test.go b/server/oscar/auth_test.go index 56f0a9f0..1346047b 100644 --- a/server/oscar/auth_test.go +++ b/server/oscar/auth_test.go @@ -80,7 +80,7 @@ func TestBUCPAuthService_handleNewConnection(t *testing.T) { Body: wire.SNAC_0x17_0x07_BUCPChallengeResponse{}, }, nil) authService.EXPECT(). - BUCPLogin(mock.Anything, mock.Anything, mock.Anything). + BUCPLogin(mock.Anything, mock.Anything). Return(wire.SNACMessage{ Frame: wire.SNACFrame{ FoodGroup: wire.BUCP, diff --git a/server/oscar/bos.go b/server/oscar/bos.go index e944a19a..cc4f2f11 100644 --- a/server/oscar/bos.go +++ b/server/oscar/bos.go @@ -1,7 +1,6 @@ package oscar import ( - "bytes" "context" "errors" "io" @@ -24,8 +23,10 @@ type OnlineNotifier interface { // service. type BOSServer struct { AuthService + CookieCracker Handler - Logger *slog.Logger + ListenAddr string + Logger *slog.Logger OnlineNotifier config.Config } @@ -34,15 +35,14 @@ type BOSServer struct { // authentication handshake sequences are handled by this method. The remaining // requests are relayed to BOSRouter. func (rt BOSServer) Start() { - addr := net.JoinHostPort("", rt.Config.BOSPort) - listener, err := net.Listen("tcp", addr) + listener, err := net.Listen("tcp", rt.ListenAddr) if err != nil { rt.Logger.Error("unable to bind BOS server address", "err", err.Error()) os.Exit(1) } defer listener.Close() - rt.Logger.Info("starting BOS service", "host", net.JoinHostPort(rt.Config.OSCARHost, rt.Config.BOSPort)) + rt.Logger.Info("starting service", "host", net.JoinHostPort(rt.Config.OSCARHost, rt.Config.BOSPort)) for { conn, err := listener.Accept() @@ -72,16 +72,17 @@ func (rt BOSServer) handleNewConnection(ctx context.Context, rwc io.ReadWriteClo return err } - var ok bool - sessionID, ok := flap.Slice(wire.OServiceTLVTagsLoginCookie) + authCookie, ok := flap.Slice(wire.OServiceTLVTagsLoginCookie) if !ok { return errors.New("unable to get session id from payload") } - // Trim the padding added to the auth cookie by the auth service. - sessionID = bytes.TrimRight(sessionID, "\x00") - sess, err := rt.RetrieveBOSSession(string(sessionID)) + token, err := rt.CookieCracker.Crack(authCookie) + if err != nil { + return err + } + sess, err := rt.RegisterBOSSession(string(token)) if err != nil { return err } diff --git a/server/oscar/bos_test.go b/server/oscar/bos_test.go index de0c2291..5366332a 100644 --- a/server/oscar/bos_test.go +++ b/server/oscar/bos_test.go @@ -30,7 +30,6 @@ func (m pipeRWC) Close() error { func TestBOSService_handleNewConnection(t *testing.T) { sess := state.NewSession() - sess.SetID("login-cookie-1234") clientReader, serverWriter := io.Pipe() serverReader, clientWriter := io.Pipe() @@ -48,11 +47,7 @@ func TestBOSService_handleNewConnection(t *testing.T) { flapSignonFrame = wire.FLAPSignonFrame{ FLAPVersion: 1, } - // create padded auth cookie - cookie := make([]byte, 220) - copy(cookie, sess.ID()) - flapSignonFrame.Append(wire.NewTLV(wire.OServiceTLVTagsLoginCookie, cookie)) - + flapSignonFrame.Append(wire.NewTLV(wire.OServiceTLVTagsLoginCookie, []byte("the-cookie"))) buf = &bytes.Buffer{} assert.NoError(t, wire.Marshal(flapSignonFrame, buf)) flap = wire.FLAPFrame{ @@ -86,7 +81,7 @@ func TestBOSService_handleNewConnection(t *testing.T) { authService := newMockAuthService(t) authService.EXPECT(). - RetrieveBOSSession(sess.ID()). + RegisterBOSSession("user_screen_name"). Return(sess, nil) authService.EXPECT(). Signout(mock.Anything, sess). @@ -103,6 +98,11 @@ func TestBOSService_handleNewConnection(t *testing.T) { Body: wire.SNAC_0x01_0x03_OServiceHostOnline{}, }) + cookieCracker := newMockCookieCracker(t) + cookieCracker.EXPECT(). + Crack([]byte("the-cookie")). + Return([]byte("user_screen_name"), nil) + router := newMockHandler(t) router.EXPECT(). Handle(mock.Anything, sess, mock.Anything, mock.Anything, mock.Anything). @@ -115,6 +115,7 @@ func TestBOSService_handleNewConnection(t *testing.T) { rt := BOSServer{ AuthService: authService, + CookieCracker: cookieCracker, Handler: router, Logger: slog.Default(), OnlineNotifier: onlineNotifier, diff --git a/server/oscar/chat.go b/server/oscar/chat.go index 231449c5..f86b13e9 100644 --- a/server/oscar/chat.go +++ b/server/oscar/chat.go @@ -21,6 +21,7 @@ type ChatServer struct { Logger *slog.Logger OnlineNotifier config.Config + CookieCracker } // Start creates a TCP server that implements that chat flow. @@ -68,7 +69,12 @@ func (rt ChatServer) handleNewConnection(ctx context.Context, rwc io.ReadWriteCl return errors.New("unable to get login cookie from payload") } - chatSess, err := rt.RetrieveChatSession(loginCookie) + token, err := rt.CookieCracker.Crack(loginCookie) + if err != nil { + return err + } + + chatSess, err := rt.RegisterChatSession(token) if err != nil { return err } diff --git a/server/oscar/chat_nav.go b/server/oscar/chat_nav.go deleted file mode 100644 index 68f0e6fd..00000000 --- a/server/oscar/chat_nav.go +++ /dev/null @@ -1,106 +0,0 @@ -package oscar - -import ( - "context" - "errors" - "io" - "log/slog" - "net" - "os" - - "github.com/mk6i/retro-aim-server/config" - "github.com/mk6i/retro-aim-server/state" - "github.com/mk6i/retro-aim-server/wire" -) - -// ChatNavServer provides client connection lifecycle management for the -// ChatNav service. This service is only used by AIM 4.x clients that make a -// separate ChatNav TCP connection. AIM 5.x clients call the ChatNav food group -// provided by BOS without creating an additional TCP connection. -type ChatNavServer struct { - AuthService - Handler - Logger *slog.Logger - OnlineNotifier - config.Config -} - -// Start starts a TCP server and listens for ChatNav connections. -func (rt ChatNavServer) Start() { - addr := net.JoinHostPort("", rt.Config.ChatNavPort) - listener, err := net.Listen("tcp", addr) - if err != nil { - rt.Logger.Error("unable to bind chat nav server address", "err", err.Error()) - os.Exit(1) - } - defer listener.Close() - - rt.Logger.Info("starting chat nav service", "host", net.JoinHostPort(rt.Config.OSCARHost, rt.Config.ChatNavPort)) - - for { - conn, err := listener.Accept() - if err != nil { - rt.Logger.Error(err.Error()) - continue - } - ctx := context.Background() - ctx = context.WithValue(ctx, "ip", conn.RemoteAddr().String()) - rt.Logger.DebugContext(ctx, "accepted connection") - go func() { - if err := rt.handleNewConnection(ctx, conn); err != nil { - rt.Logger.Info("user session failed", "err", err.Error()) - } - }() - } -} - -func (rt ChatNavServer) handleNewConnection(ctx context.Context, rwc io.ReadWriteCloser) error { - flapc := wire.NewFlapClient(100, rwc, rwc) - - if err := flapc.SendSignonFrame(nil); err != nil { - return err - } - flap, err := flapc.ReceiveSignonFrame() - if err != nil { - return err - } - - var ok bool - sessionID, ok := flap.Slice(wire.OServiceTLVTagsLoginCookie) - if !ok { - return errors.New("unable to get session id from payload") - } - - bosSess, err := rt.RetrieveBOSSession(string(sessionID)) - if err != nil { - return err - } - if bosSess == nil { - return errors.New("session not found") - } - - defer func() { - rwc.Close() - }() - - ctx = context.WithValue(ctx, "screenName", bosSess.ScreenName()) - - msg := rt.OnlineNotifier.HostOnline() - if err := flapc.SendSNAC(msg.Frame, msg.Body); err != nil { - return err - } - - // We copy the session object here to make sure that - // dispatchIncomingMessages does not consume relayed messages produced by - // the BOS server. Without this hack, message consumption would be split - // between the BOS server and ChatNav server, which would result in - // incorrect sequence number generation, because each server has its own - // sequence counter. This hack can be removed by decoupling FLAP routing - // and message relaying, which are both performed in - // dispatchIncomingMessages. - sessCopy := state.NewSession() - sessCopy.SetScreenName(bosSess.ScreenName()) - sessCopy.SetID(bosSess.ID()) - - return dispatchIncomingMessages(ctx, sessCopy, flapc, rwc, rt.Logger, rt.Handler, rt.Config) -} diff --git a/server/oscar/chat_nav_test.go b/server/oscar/chat_nav_test.go deleted file mode 100644 index 0745a35f..00000000 --- a/server/oscar/chat_nav_test.go +++ /dev/null @@ -1,108 +0,0 @@ -package oscar - -import ( - "bytes" - "context" - "io" - "log/slog" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - - "github.com/mk6i/retro-aim-server/state" - "github.com/mk6i/retro-aim-server/wire" -) - -func TestChatNavServer_handleNewConnection(t *testing.T) { - sess := state.NewSession() - sess.SetID("login-cookie-1234") - - clientReader, serverWriter := io.Pipe() - serverReader, clientWriter := io.Pipe() - - go func() { - // < receive FLAPSignonFrame - flap := wire.FLAPFrame{} - assert.NoError(t, wire.Unmarshal(&flap, serverReader)) - buf, err := flap.ReadBody(serverReader) - assert.NoError(t, err) - flapSignonFrame := wire.FLAPSignonFrame{} - assert.NoError(t, wire.Unmarshal(&flapSignonFrame, buf)) - - // > send FLAPSignonFrame - flapSignonFrame = wire.FLAPSignonFrame{ - FLAPVersion: 1, - } - flapSignonFrame.Append(wire.NewTLV(wire.OServiceTLVTagsLoginCookie, []byte(sess.ID()))) - buf = &bytes.Buffer{} - assert.NoError(t, wire.Marshal(flapSignonFrame, buf)) - flap = wire.FLAPFrame{ - StartMarker: 42, - FrameType: wire.FLAPFrameSignon, - PayloadLength: uint16(buf.Len()), - } - assert.NoError(t, wire.Marshal(flap, serverWriter)) - _, err = serverWriter.Write(buf.Bytes()) - assert.NoError(t, err) - - // < receive SNAC_0x01_0x03_OServiceHostOnline - flap = wire.FLAPFrame{} - assert.NoError(t, wire.Unmarshal(&flap, serverReader)) - buf, err = flap.ReadBody(serverReader) - assert.NoError(t, err) - frame := wire.SNACFrame{} - assert.NoError(t, wire.Unmarshal(&frame, buf)) - body := wire.SNAC_0x01_0x03_OServiceHostOnline{} - assert.NoError(t, wire.Unmarshal(&body, buf)) - - // send the first request that should get relayed to Handler - flapc := wire.NewFlapClient(0, nil, serverWriter) - frame = wire.SNACFrame{ - FoodGroup: wire.OService, - SubGroup: wire.OServiceClientOnline, - } - assert.NoError(t, flapc.SendSNAC(frame, struct{}{})) - assert.NoError(t, serverWriter.Close()) - }() - - authService := newMockAuthService(t) - authService.EXPECT(). - RetrieveBOSSession(sess.ID()). - Return(sess, nil) - - onlineNotifier := newMockOnlineNotifier(t) - onlineNotifier.EXPECT(). - HostOnline(). - Return(wire.SNACMessage{ - Frame: wire.SNACFrame{ - FoodGroup: wire.OService, - SubGroup: wire.OServiceHostOnline, - }, - Body: wire.SNAC_0x01_0x03_OServiceHostOnline{}, - }) - - // we can't assert that the same sess instance created above is passed as a - // parameter to Handle because the session instance is copied - router := newMockHandler(t) - router.EXPECT(). - Handle(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). - Run(func(ctx context.Context, sess *state.Session, inFrame wire.SNACFrame, r io.Reader, rw ResponseWriter) { - assert.Equal(t, wire.SNACFrame{ - FoodGroup: wire.OService, - SubGroup: wire.OServiceClientOnline, - }, inFrame) - }).Return(nil) - - rt := ChatNavServer{ - AuthService: authService, - Handler: router, - Logger: slog.Default(), - OnlineNotifier: onlineNotifier, - } - rwc := pipeRWC{ - PipeReader: clientReader, - PipeWriter: clientWriter, - } - assert.NoError(t, rt.handleNewConnection(context.Background(), rwc)) -} diff --git a/server/oscar/chat_test.go b/server/oscar/chat_test.go index 06e6f329..38fc465d 100644 --- a/server/oscar/chat_test.go +++ b/server/oscar/chat_test.go @@ -16,7 +16,6 @@ import ( func TestChatService_handleNewConnection(t *testing.T) { sess := state.NewSession() - sess.SetID("session-id-1234") clientReader, serverWriter := io.Pipe() serverReader, clientWriter := io.Pipe() @@ -68,7 +67,7 @@ func TestChatService_handleNewConnection(t *testing.T) { authService := newMockAuthService(t) authService.EXPECT(). - RetrieveChatSession([]byte(`the-chat-login-cookie`)). + RegisterChatSession([]byte(`user-screen-name`)). Return(sess, nil) authService.EXPECT(). SignoutChat(mock.Anything, sess). @@ -85,6 +84,11 @@ func TestChatService_handleNewConnection(t *testing.T) { Body: wire.SNAC_0x01_0x03_OServiceHostOnline{}, }) + cookieCracker := newMockCookieCracker(t) + cookieCracker.EXPECT(). + Crack([]byte(`the-chat-login-cookie`)). + Return([]byte(`user-screen-name`), nil) + bosRouter := newMockHandler(t) bosRouter.EXPECT(). Handle(mock.Anything, sess, mock.Anything, mock.Anything, mock.Anything). @@ -92,6 +96,7 @@ func TestChatService_handleNewConnection(t *testing.T) { rt := ChatServer{ AuthService: authService, + CookieCracker: cookieCracker, Handler: bosRouter, Logger: slog.Default(), OnlineNotifier: onlineNotifier, diff --git a/server/oscar/connection_test.go b/server/oscar/connection_test.go index 120dea2b..2843ba89 100644 --- a/server/oscar/connection_test.go +++ b/server/oscar/connection_test.go @@ -18,7 +18,7 @@ import ( func TestHandleChatConnection_MessageRelay(t *testing.T) { sessionManager := state.NewInMemorySessionManager(slog.Default()) // add a user to session that will receive relayed messages - sess := sessionManager.AddSession("bob-sess-id", "bob") + sess := sessionManager.AddSession("bob") // start the server connection handler in the background serverReader, _ := io.Pipe() @@ -91,7 +91,7 @@ func TestHandleChatConnection_MessageRelay(t *testing.T) { func TestHandleChatConnection_ClientRequest(t *testing.T) { sessionManager := state.NewInMemorySessionManager(slog.Default()) // add session so that the function can terminate upon closure - sess := sessionManager.AddSession("bob-sess-id", "bob") + sess := sessionManager.AddSession("bob") inboundMsgs := []wire.SNACMessage{ { diff --git a/server/oscar/mock_auth_test.go b/server/oscar/mock_auth_test.go index 8f2e4533..73a2da66 100644 --- a/server/oscar/mock_auth_test.go +++ b/server/oscar/mock_auth_test.go @@ -83,9 +83,9 @@ func (_c *mockAuthService_BUCPChallenge_Call) RunAndReturn(run func(wire.SNAC_0x return _c } -// BUCPLogin provides a mock function with given fields: bodyIn, newUUID, fn -func (_m *mockAuthService) BUCPLogin(bodyIn wire.SNAC_0x17_0x02_BUCPLoginRequest, newUUID func() uuid.UUID, fn func(string) (state.User, error)) (wire.SNACMessage, error) { - ret := _m.Called(bodyIn, newUUID, fn) +// BUCPLogin provides a mock function with given fields: bodyIn, fn +func (_m *mockAuthService) BUCPLogin(bodyIn wire.SNAC_0x17_0x02_BUCPLoginRequest, fn func(string) (state.User, error)) (wire.SNACMessage, error) { + ret := _m.Called(bodyIn, fn) if len(ret) == 0 { panic("no return value specified for BUCPLogin") @@ -93,17 +93,17 @@ func (_m *mockAuthService) BUCPLogin(bodyIn wire.SNAC_0x17_0x02_BUCPLoginRequest var r0 wire.SNACMessage var r1 error - if rf, ok := ret.Get(0).(func(wire.SNAC_0x17_0x02_BUCPLoginRequest, func() uuid.UUID, func(string) (state.User, error)) (wire.SNACMessage, error)); ok { - return rf(bodyIn, newUUID, fn) + if rf, ok := ret.Get(0).(func(wire.SNAC_0x17_0x02_BUCPLoginRequest, func(string) (state.User, error)) (wire.SNACMessage, error)); ok { + return rf(bodyIn, fn) } - if rf, ok := ret.Get(0).(func(wire.SNAC_0x17_0x02_BUCPLoginRequest, func() uuid.UUID, func(string) (state.User, error)) wire.SNACMessage); ok { - r0 = rf(bodyIn, newUUID, fn) + if rf, ok := ret.Get(0).(func(wire.SNAC_0x17_0x02_BUCPLoginRequest, func(string) (state.User, error)) wire.SNACMessage); ok { + r0 = rf(bodyIn, fn) } else { r0 = ret.Get(0).(wire.SNACMessage) } - if rf, ok := ret.Get(1).(func(wire.SNAC_0x17_0x02_BUCPLoginRequest, func() uuid.UUID, func(string) (state.User, error)) error); ok { - r1 = rf(bodyIn, newUUID, fn) + if rf, ok := ret.Get(1).(func(wire.SNAC_0x17_0x02_BUCPLoginRequest, func(string) (state.User, error)) error); ok { + r1 = rf(bodyIn, fn) } else { r1 = ret.Error(1) } @@ -118,15 +118,14 @@ type mockAuthService_BUCPLogin_Call struct { // BUCPLogin is a helper method to define mock.On call // - bodyIn wire.SNAC_0x17_0x02_BUCPLoginRequest -// - newUUID func() uuid.UUID // - fn func(string)(state.User , error) -func (_e *mockAuthService_Expecter) BUCPLogin(bodyIn interface{}, newUUID interface{}, fn interface{}) *mockAuthService_BUCPLogin_Call { - return &mockAuthService_BUCPLogin_Call{Call: _e.mock.On("BUCPLogin", bodyIn, newUUID, fn)} +func (_e *mockAuthService_Expecter) BUCPLogin(bodyIn interface{}, fn interface{}) *mockAuthService_BUCPLogin_Call { + return &mockAuthService_BUCPLogin_Call{Call: _e.mock.On("BUCPLogin", bodyIn, fn)} } -func (_c *mockAuthService_BUCPLogin_Call) Run(run func(bodyIn wire.SNAC_0x17_0x02_BUCPLoginRequest, newUUID func() uuid.UUID, fn func(string) (state.User, error))) *mockAuthService_BUCPLogin_Call { +func (_c *mockAuthService_BUCPLogin_Call) Run(run func(bodyIn wire.SNAC_0x17_0x02_BUCPLoginRequest, fn func(string) (state.User, error))) *mockAuthService_BUCPLogin_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(wire.SNAC_0x17_0x02_BUCPLoginRequest), args[1].(func() uuid.UUID), args[2].(func(string) (state.User, error))) + run(args[0].(wire.SNAC_0x17_0x02_BUCPLoginRequest), args[1].(func(string) (state.User, error))) }) return _c } @@ -136,14 +135,14 @@ func (_c *mockAuthService_BUCPLogin_Call) Return(_a0 wire.SNACMessage, _a1 error return _c } -func (_c *mockAuthService_BUCPLogin_Call) RunAndReturn(run func(wire.SNAC_0x17_0x02_BUCPLoginRequest, func() uuid.UUID, func(string) (state.User, error)) (wire.SNACMessage, error)) *mockAuthService_BUCPLogin_Call { +func (_c *mockAuthService_BUCPLogin_Call) RunAndReturn(run func(wire.SNAC_0x17_0x02_BUCPLoginRequest, func(string) (state.User, error)) (wire.SNACMessage, error)) *mockAuthService_BUCPLogin_Call { _c.Call.Return(run) return _c } -// FLAPLogin provides a mock function with given fields: frame, newUUIDFn, newUserFn -func (_m *mockAuthService) FLAPLogin(frame wire.FLAPSignonFrame, newUUIDFn func() uuid.UUID, newUserFn func(string) (state.User, error)) (wire.TLVRestBlock, error) { - ret := _m.Called(frame, newUUIDFn, newUserFn) +// FLAPLogin provides a mock function with given fields: frame, newUserFn +func (_m *mockAuthService) FLAPLogin(frame wire.FLAPSignonFrame, newUserFn func(string) (state.User, error)) (wire.TLVRestBlock, error) { + ret := _m.Called(frame, newUserFn) if len(ret) == 0 { panic("no return value specified for FLAPLogin") @@ -151,17 +150,17 @@ func (_m *mockAuthService) FLAPLogin(frame wire.FLAPSignonFrame, newUUIDFn func( var r0 wire.TLVRestBlock var r1 error - if rf, ok := ret.Get(0).(func(wire.FLAPSignonFrame, func() uuid.UUID, func(string) (state.User, error)) (wire.TLVRestBlock, error)); ok { - return rf(frame, newUUIDFn, newUserFn) + if rf, ok := ret.Get(0).(func(wire.FLAPSignonFrame, func(string) (state.User, error)) (wire.TLVRestBlock, error)); ok { + return rf(frame, newUserFn) } - if rf, ok := ret.Get(0).(func(wire.FLAPSignonFrame, func() uuid.UUID, func(string) (state.User, error)) wire.TLVRestBlock); ok { - r0 = rf(frame, newUUIDFn, newUserFn) + if rf, ok := ret.Get(0).(func(wire.FLAPSignonFrame, func(string) (state.User, error)) wire.TLVRestBlock); ok { + r0 = rf(frame, newUserFn) } else { r0 = ret.Get(0).(wire.TLVRestBlock) } - if rf, ok := ret.Get(1).(func(wire.FLAPSignonFrame, func() uuid.UUID, func(string) (state.User, error)) error); ok { - r1 = rf(frame, newUUIDFn, newUserFn) + if rf, ok := ret.Get(1).(func(wire.FLAPSignonFrame, func(string) (state.User, error)) error); ok { + r1 = rf(frame, newUserFn) } else { r1 = ret.Error(1) } @@ -176,15 +175,14 @@ type mockAuthService_FLAPLogin_Call struct { // FLAPLogin is a helper method to define mock.On call // - frame wire.FLAPSignonFrame -// - newUUIDFn func() uuid.UUID // - newUserFn func(string)(state.User , error) -func (_e *mockAuthService_Expecter) FLAPLogin(frame interface{}, newUUIDFn interface{}, newUserFn interface{}) *mockAuthService_FLAPLogin_Call { - return &mockAuthService_FLAPLogin_Call{Call: _e.mock.On("FLAPLogin", frame, newUUIDFn, newUserFn)} +func (_e *mockAuthService_Expecter) FLAPLogin(frame interface{}, newUserFn interface{}) *mockAuthService_FLAPLogin_Call { + return &mockAuthService_FLAPLogin_Call{Call: _e.mock.On("FLAPLogin", frame, newUserFn)} } -func (_c *mockAuthService_FLAPLogin_Call) Run(run func(frame wire.FLAPSignonFrame, newUUIDFn func() uuid.UUID, newUserFn func(string) (state.User, error))) *mockAuthService_FLAPLogin_Call { +func (_c *mockAuthService_FLAPLogin_Call) Run(run func(frame wire.FLAPSignonFrame, newUserFn func(string) (state.User, error))) *mockAuthService_FLAPLogin_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(wire.FLAPSignonFrame), args[1].(func() uuid.UUID), args[2].(func(string) (state.User, error))) + run(args[0].(wire.FLAPSignonFrame), args[1].(func(string) (state.User, error))) }) return _c } @@ -194,17 +192,17 @@ func (_c *mockAuthService_FLAPLogin_Call) Return(_a0 wire.TLVRestBlock, _a1 erro return _c } -func (_c *mockAuthService_FLAPLogin_Call) RunAndReturn(run func(wire.FLAPSignonFrame, func() uuid.UUID, func(string) (state.User, error)) (wire.TLVRestBlock, error)) *mockAuthService_FLAPLogin_Call { +func (_c *mockAuthService_FLAPLogin_Call) RunAndReturn(run func(wire.FLAPSignonFrame, func(string) (state.User, error)) (wire.TLVRestBlock, error)) *mockAuthService_FLAPLogin_Call { _c.Call.Return(run) return _c } -// RetrieveBOSSession provides a mock function with given fields: sessionID -func (_m *mockAuthService) RetrieveBOSSession(sessionID string) (*state.Session, error) { +// RegisterBOSSession provides a mock function with given fields: sessionID +func (_m *mockAuthService) RegisterBOSSession(sessionID string) (*state.Session, error) { ret := _m.Called(sessionID) if len(ret) == 0 { - panic("no return value specified for RetrieveBOSSession") + panic("no return value specified for RegisterBOSSession") } var r0 *state.Session @@ -229,40 +227,40 @@ func (_m *mockAuthService) RetrieveBOSSession(sessionID string) (*state.Session, return r0, r1 } -// mockAuthService_RetrieveBOSSession_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RetrieveBOSSession' -type mockAuthService_RetrieveBOSSession_Call struct { +// mockAuthService_RegisterBOSSession_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RegisterBOSSession' +type mockAuthService_RegisterBOSSession_Call struct { *mock.Call } -// RetrieveBOSSession is a helper method to define mock.On call +// RegisterBOSSession is a helper method to define mock.On call // - sessionID string -func (_e *mockAuthService_Expecter) RetrieveBOSSession(sessionID interface{}) *mockAuthService_RetrieveBOSSession_Call { - return &mockAuthService_RetrieveBOSSession_Call{Call: _e.mock.On("RetrieveBOSSession", sessionID)} +func (_e *mockAuthService_Expecter) RegisterBOSSession(sessionID interface{}) *mockAuthService_RegisterBOSSession_Call { + return &mockAuthService_RegisterBOSSession_Call{Call: _e.mock.On("RegisterBOSSession", sessionID)} } -func (_c *mockAuthService_RetrieveBOSSession_Call) Run(run func(sessionID string)) *mockAuthService_RetrieveBOSSession_Call { +func (_c *mockAuthService_RegisterBOSSession_Call) Run(run func(sessionID string)) *mockAuthService_RegisterBOSSession_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(string)) }) return _c } -func (_c *mockAuthService_RetrieveBOSSession_Call) Return(_a0 *state.Session, _a1 error) *mockAuthService_RetrieveBOSSession_Call { +func (_c *mockAuthService_RegisterBOSSession_Call) Return(_a0 *state.Session, _a1 error) *mockAuthService_RegisterBOSSession_Call { _c.Call.Return(_a0, _a1) return _c } -func (_c *mockAuthService_RetrieveBOSSession_Call) RunAndReturn(run func(string) (*state.Session, error)) *mockAuthService_RetrieveBOSSession_Call { +func (_c *mockAuthService_RegisterBOSSession_Call) RunAndReturn(run func(string) (*state.Session, error)) *mockAuthService_RegisterBOSSession_Call { _c.Call.Return(run) return _c } -// RetrieveChatSession provides a mock function with given fields: loginCookie -func (_m *mockAuthService) RetrieveChatSession(loginCookie []byte) (*state.Session, error) { +// RegisterChatSession provides a mock function with given fields: loginCookie +func (_m *mockAuthService) RegisterChatSession(loginCookie []byte) (*state.Session, error) { ret := _m.Called(loginCookie) if len(ret) == 0 { - panic("no return value specified for RetrieveChatSession") + panic("no return value specified for RegisterChatSession") } var r0 *state.Session @@ -287,30 +285,30 @@ func (_m *mockAuthService) RetrieveChatSession(loginCookie []byte) (*state.Sessi return r0, r1 } -// mockAuthService_RetrieveChatSession_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RetrieveChatSession' -type mockAuthService_RetrieveChatSession_Call struct { +// mockAuthService_RegisterChatSession_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RegisterChatSession' +type mockAuthService_RegisterChatSession_Call struct { *mock.Call } -// RetrieveChatSession is a helper method to define mock.On call +// RegisterChatSession is a helper method to define mock.On call // - loginCookie []byte -func (_e *mockAuthService_Expecter) RetrieveChatSession(loginCookie interface{}) *mockAuthService_RetrieveChatSession_Call { - return &mockAuthService_RetrieveChatSession_Call{Call: _e.mock.On("RetrieveChatSession", loginCookie)} +func (_e *mockAuthService_Expecter) RegisterChatSession(loginCookie interface{}) *mockAuthService_RegisterChatSession_Call { + return &mockAuthService_RegisterChatSession_Call{Call: _e.mock.On("RegisterChatSession", loginCookie)} } -func (_c *mockAuthService_RetrieveChatSession_Call) Run(run func(loginCookie []byte)) *mockAuthService_RetrieveChatSession_Call { +func (_c *mockAuthService_RegisterChatSession_Call) Run(run func(loginCookie []byte)) *mockAuthService_RegisterChatSession_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].([]byte)) }) return _c } -func (_c *mockAuthService_RetrieveChatSession_Call) Return(_a0 *state.Session, _a1 error) *mockAuthService_RetrieveChatSession_Call { +func (_c *mockAuthService_RegisterChatSession_Call) Return(_a0 *state.Session, _a1 error) *mockAuthService_RegisterChatSession_Call { _c.Call.Return(_a0, _a1) return _c } -func (_c *mockAuthService_RetrieveChatSession_Call) RunAndReturn(run func([]byte) (*state.Session, error)) *mockAuthService_RetrieveChatSession_Call { +func (_c *mockAuthService_RegisterChatSession_Call) RunAndReturn(run func([]byte) (*state.Session, error)) *mockAuthService_RegisterChatSession_Call { _c.Call.Return(run) return _c } diff --git a/server/oscar/mock_cookie_cracker_test.go b/server/oscar/mock_cookie_cracker_test.go new file mode 100644 index 00000000..9269aa89 --- /dev/null +++ b/server/oscar/mock_cookie_cracker_test.go @@ -0,0 +1,90 @@ +// Code generated by mockery v2.40.1. DO NOT EDIT. + +package oscar + +import mock "github.com/stretchr/testify/mock" + +// mockCookieCracker is an autogenerated mock type for the CookieCracker type +type mockCookieCracker struct { + mock.Mock +} + +type mockCookieCracker_Expecter struct { + mock *mock.Mock +} + +func (_m *mockCookieCracker) EXPECT() *mockCookieCracker_Expecter { + return &mockCookieCracker_Expecter{mock: &_m.Mock} +} + +// Crack provides a mock function with given fields: data +func (_m *mockCookieCracker) Crack(data []byte) ([]byte, error) { + ret := _m.Called(data) + + if len(ret) == 0 { + panic("no return value specified for Crack") + } + + var r0 []byte + var r1 error + if rf, ok := ret.Get(0).(func([]byte) ([]byte, error)); ok { + return rf(data) + } + if rf, ok := ret.Get(0).(func([]byte) []byte); ok { + r0 = rf(data) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]byte) + } + } + + if rf, ok := ret.Get(1).(func([]byte) error); ok { + r1 = rf(data) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// mockCookieCracker_Crack_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Crack' +type mockCookieCracker_Crack_Call struct { + *mock.Call +} + +// Crack is a helper method to define mock.On call +// - data []byte +func (_e *mockCookieCracker_Expecter) Crack(data interface{}) *mockCookieCracker_Crack_Call { + return &mockCookieCracker_Crack_Call{Call: _e.mock.On("Crack", data)} +} + +func (_c *mockCookieCracker_Crack_Call) Run(run func(data []byte)) *mockCookieCracker_Crack_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].([]byte)) + }) + return _c +} + +func (_c *mockCookieCracker_Crack_Call) Return(_a0 []byte, _a1 error) *mockCookieCracker_Crack_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *mockCookieCracker_Crack_Call) RunAndReturn(run func([]byte) ([]byte, error)) *mockCookieCracker_Crack_Call { + _c.Call.Return(run) + return _c +} + +// newMockCookieCracker creates a new instance of mockCookieCracker. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func newMockCookieCracker(t interface { + mock.TestingT + Cleanup(func()) +}) *mockCookieCracker { + mock := &mockCookieCracker{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/server/oscar/types.go b/server/oscar/types.go new file mode 100644 index 00000000..7855dac2 --- /dev/null +++ b/server/oscar/types.go @@ -0,0 +1,5 @@ +package oscar + +type CookieCracker interface { + Crack(data []byte) ([]byte, error) +} diff --git a/state/cookie.go b/state/cookie.go new file mode 100644 index 00000000..5142d29a --- /dev/null +++ b/state/cookie.go @@ -0,0 +1,111 @@ +package state + +import ( + "bytes" + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "errors" + "fmt" + "io" + "time" + + "github.com/mk6i/retro-aim-server/wire" +) + +// authCookieLen is the fixed auth cookie length. +const authCookieLen = 256 + +func NewHMACCookieBaker() (HMACCookieBaker, error) { + cb := HMACCookieBaker{} + cb.key = make([]byte, 32) + if _, err := io.ReadFull(rand.Reader, cb.key); err != nil { + return cb, fmt.Errorf("cannot generate random HMAC key: %w", err) + } + return cb, nil +} + +type HMACCookieBaker struct { + key []byte +} + +func (c HMACCookieBaker) Issue(data []byte) ([]byte, error) { + payload := hmacTokenPayload{ + Expiry: uint32(time.Now().Add(1 * time.Minute).Unix()), + Data: data, + } + buf := &bytes.Buffer{} + if err := wire.Marshal(payload, buf); err != nil { + return nil, fmt.Errorf("unable to marshal auth authCookie: %w", err) + } + + hmacTok := hmacToken{ + Data: buf.Bytes(), + } + hmacTok.hash(c.key) + + buf.Reset() + + if err := wire.Marshal(hmacTok, buf); err != nil { + return nil, fmt.Errorf("unable to marshal auth authCookie: %w", err) + } + + // Some clients (such as perl NET::OSCAR) expect the auth cookie to be + // exactly 256 bytes, even though the cookie is stored in a + // variable-length TLV. Pad the auth cookie to make sure it's exactly + // 256 bytes. + if buf.Len() > authCookieLen { + return nil, fmt.Errorf("sess is too long, expect 256 bytes, got %d", buf.Len()) + } + buf.Write(make([]byte, authCookieLen-buf.Len())) + + return buf.Bytes(), nil +} + +func (c HMACCookieBaker) Crack(data []byte) ([]byte, error) { + hmacTok := hmacToken{} + if err := wire.Unmarshal(&hmacTok, bytes.NewBuffer(data)); err != nil { + return nil, fmt.Errorf("unable to unmarshal HMAC cooie: %w", err) + } + + if !hmacTok.validate(c.key) { + return nil, errors.New("invalid HMAC cookie") + } + + payload := hmacTokenPayload{} + if err := wire.Unmarshal(&payload, bytes.NewBuffer(hmacTok.Data)); err != nil { + return nil, fmt.Errorf("unable to unmarshal HMAC cookie payload: %w", err) + } + + expiry := time.Unix(int64(payload.Expiry), 0) + if expiry.Before(time.Now()) { + return nil, errors.New("HMAC cookie expired") + } + + return payload.Data, nil +} + +type hmacTokenPayload struct { + Expiry uint32 + Data []byte `len_prefix:"uint16"` +} + +type hmacToken struct { + Data []byte `len_prefix:"uint16"` + Sig []byte `len_prefix:"uint16"` +} + +func (h *hmacToken) hash(key []byte) { + hs := hmac.New(sha256.New, key) + if _, err := hs.Write(h.Data); err != nil { + // according to Hash interface, Write() should never return an error + panic("unable to compute hmac token") + } + h.Sig = hs.Sum(nil) +} + +func (h *hmacToken) validate(key []byte) bool { + cp := *h + cp.hash(key) + return hmac.Equal(h.Sig, cp.Sig) +} diff --git a/state/session.go b/state/session.go index 269e0dc3..3b58dc82 100644 --- a/state/session.go +++ b/state/session.go @@ -26,7 +26,6 @@ type Session struct { awayMessage string chatRoomCookie string closed bool - id string idle bool idleTime time.Time invisible bool @@ -89,20 +88,6 @@ func (s *Session) ScreenName() string { return s.screenName } -// SetID sets the user's session ID. -func (s *Session) SetID(ID string) { - s.mutex.Lock() - defer s.mutex.Unlock() - s.id = ID -} - -// ID returns the user's session ID. -func (s *Session) ID() string { - s.mutex.RLock() - defer s.mutex.RUnlock() - return s.id -} - // SetSignonTime sets the user's sign-ontime. func (s *Session) SetSignonTime(t time.Time) { s.mutex.Lock() diff --git a/state/session_manager.go b/state/session_manager.go index 3e4384e2..24d75933 100644 --- a/state/session_manager.go +++ b/state/session_manager.go @@ -77,7 +77,7 @@ func (s *InMemorySessionManager) maybeRelayMessage(ctx context.Context, msg wire // AddSession adds a new session to the pool. It replaces an existing session // with a matching screen name, ensuring that each screen name is unique in the // pool. -func (s *InMemorySessionManager) AddSession(sessID string, screenName string) *Session { +func (s *InMemorySessionManager) AddSession(screenName string) *Session { s.mapMutex.Lock() defer s.mapMutex.Unlock() @@ -89,15 +89,14 @@ func (s *InMemorySessionManager) AddSession(sessID string, screenName string) *S for _, sess := range s.store { if screenName == sess.ScreenName() { sess.Close() - delete(s.store, sess.ID()) + delete(s.store, screenName) break } } sess := NewSession() - sess.SetID(sessID) sess.SetScreenName(screenName) - s.store[sess.ID()] = sess + s.store[sess.ScreenName()] = sess return sess } @@ -105,7 +104,7 @@ func (s *InMemorySessionManager) AddSession(sessID string, screenName string) *S func (s *InMemorySessionManager) RemoveSession(sess *Session) { s.mapMutex.Lock() defer s.mapMutex.Unlock() - delete(s.store, sess.ID()) + delete(s.store, sess.ScreenName()) } // RetrieveSession finds a session with a matching sessionID. Returns nil if diff --git a/state/session_manager_test.go b/state/session_manager_test.go index f1793ae5..09d0f48d 100644 --- a/state/session_manager_test.go +++ b/state/session_manager_test.go @@ -13,11 +13,11 @@ import ( func TestInMemorySessionManager_AddSession(t *testing.T) { sm := NewInMemorySessionManager(slog.Default()) - want1 := sm.AddSession("sess-id-1", "user-screen-name") + want1 := sm.AddSession("user-screen-name") have1 := sm.RetrieveByScreenName("user-screen-name") assert.Same(t, want1, have1) - want2 := sm.AddSession("sess-id-2", "user-screen-name") + want2 := sm.AddSession("user-screen-name") have2 := sm.RetrieveByScreenName("user-screen-name") assert.Same(t, want2, have2) @@ -37,11 +37,9 @@ func TestInMemorySessionManager_Remove(t *testing.T) { name: "remove user that exists", given: []*Session{ { - id: "sess-id-1", screenName: "user-screen-name-1", }, { - id: "sess-id-2", screenName: "user-screen-name-2", }, }, @@ -56,7 +54,7 @@ func TestInMemorySessionManager_Remove(t *testing.T) { sm := NewInMemorySessionManager(slog.Default()) for _, sess := range tt.given { - sm.AddSession(sess.id, sess.screenName) + sm.AddSession(sess.screenName) } sm.RemoveSession(sm.RetrieveByScreenName(tt.remove)) @@ -79,7 +77,6 @@ func TestInMemorySessionManager_Empty(t *testing.T) { name: "session manager is not empty", given: []*Session{ { - id: "sess-id-1", screenName: "user-screen-name-1", }, }, @@ -96,7 +93,7 @@ func TestInMemorySessionManager_Empty(t *testing.T) { sm := NewInMemorySessionManager(slog.Default()) for _, sess := range tt.given { - sm.AddSession(sess.id, sess.screenName) + sm.AddSession(sess.screenName) } have := sm.Empty() @@ -107,32 +104,30 @@ func TestInMemorySessionManager_Empty(t *testing.T) { func TestInMemorySessionManager_Retrieve(t *testing.T) { tests := []struct { - name string - given []*Session - lookupID string - remove string - wantID string + name string + given []*Session + lookupScreenName string + remove string + wantScreenName string }{ { name: "lookup finds match", given: []*Session{ { - id: "sess-id-1", screenName: "user-screen-name-1", }, { - id: "sess-id-2", screenName: "user-screen-name-2", }, }, - lookupID: "sess-id-2", - wantID: "sess-id-2", + lookupScreenName: "user-screen-name-2", + wantScreenName: "user-screen-name-2", }, { - name: "lookup does not find match", - given: []*Session{}, - lookupID: "sess-id-3", - wantID: "", + name: "lookup does not find match", + given: []*Session{}, + lookupScreenName: "user-screen-name-3", + wantScreenName: "", }, } for _, tt := range tests { @@ -140,14 +135,14 @@ func TestInMemorySessionManager_Retrieve(t *testing.T) { sm := NewInMemorySessionManager(slog.Default()) for _, sess := range tt.given { - sm.AddSession(sess.id, sess.screenName) + sm.AddSession(sess.screenName) } - have := sm.RetrieveSession(tt.lookupID) + have := sm.RetrieveSession(tt.lookupScreenName) if have == nil { - assert.Empty(t, tt.wantID) + assert.Empty(t, tt.wantScreenName) } else { - assert.Equal(t, tt.wantID, have.ID()) + assert.Equal(t, tt.wantScreenName, have.ScreenName()) } }) } @@ -156,9 +151,9 @@ func TestInMemorySessionManager_Retrieve(t *testing.T) { func TestInMemorySessionManager_RelayToScreenNames(t *testing.T) { sm := NewInMemorySessionManager(slog.Default()) - user1 := sm.AddSession("sess-id-1", "user-screen-name-1") - user2 := sm.AddSession("sess-id-2", "user-screen-name-2") - user3 := sm.AddSession("sess-id-3", "user-screen-name-3") + user1 := sm.AddSession("user-screen-name-1") + user2 := sm.AddSession("user-screen-name-2") + user3 := sm.AddSession("user-screen-name-3") want := wire.SNACMessage{Frame: wire.SNACFrame{FoodGroup: wire.ICBM}} @@ -185,8 +180,8 @@ func TestInMemorySessionManager_RelayToScreenNames(t *testing.T) { func TestInMemorySessionManager_Broadcast(t *testing.T) { sm := NewInMemorySessionManager(slog.Default()) - user1 := sm.AddSession("sess-id-1", "user-screen-name-1") - user2 := sm.AddSession("sess-id-2", "user-screen-name-2") + user1 := sm.AddSession("user-screen-name-1") + user2 := sm.AddSession("user-screen-name-2") want := wire.SNACMessage{Frame: wire.SNACFrame{FoodGroup: wire.ICBM}} @@ -206,8 +201,8 @@ func TestInMemorySessionManager_Broadcast(t *testing.T) { func TestInMemorySessionManager_Broadcast_SkipClosedSession(t *testing.T) { sm := NewInMemorySessionManager(slog.Default()) - user1 := sm.AddSession("sess-id-1", "user-screen-name-1") - user2 := sm.AddSession("sess-id-2", "user-screen-name-2") + user1 := sm.AddSession("user-screen-name-1") + user2 := sm.AddSession("user-screen-name-2") user2.Close() want := wire.SNACMessage{Frame: wire.SNACFrame{FoodGroup: wire.ICBM}} @@ -229,8 +224,8 @@ func TestInMemorySessionManager_Broadcast_SkipClosedSession(t *testing.T) { func TestInMemorySessionManager_RelayToScreenName_SessionExists(t *testing.T) { sm := NewInMemorySessionManager(slog.Default()) - user1 := sm.AddSession("sess-id-1", "user-screen-name-1") - user2 := sm.AddSession("sess-id-2", "user-screen-name-2") + user1 := sm.AddSession("user-screen-name-1") + user2 := sm.AddSession("user-screen-name-2") want := wire.SNACMessage{Frame: wire.SNACFrame{FoodGroup: wire.ICBM}} @@ -252,7 +247,7 @@ func TestInMemorySessionManager_RelayToScreenName_SessionExists(t *testing.T) { func TestInMemorySessionManager_RelayToScreenName_SessionNotExist(t *testing.T) { sm := NewInMemorySessionManager(slog.Default()) - user1 := sm.AddSession("sess-id-1", "user-screen-name-1") + user1 := sm.AddSession("user-screen-name-1") want := wire.SNACMessage{Frame: wire.SNACFrame{FoodGroup: wire.ICBM}} @@ -269,7 +264,7 @@ func TestInMemorySessionManager_RelayToScreenName_SessionNotExist(t *testing.T) func TestInMemorySessionManager_RelayToScreenName_SkipFullSession(t *testing.T) { sm := NewInMemorySessionManager(slog.Default()) - user1 := sm.AddSession("sess-id-1", "user-screen-name-1") + user1 := sm.AddSession("user-screen-name-1") msg := wire.SNACMessage{Frame: wire.SNACFrame{FoodGroup: wire.ICBM}} wantCount := 0 @@ -300,9 +295,9 @@ loop: func TestInMemorySessionManager_RelayToAllExcept(t *testing.T) { sm := NewInMemorySessionManager(slog.Default()) - user1 := sm.AddSession("sess-id-1", "user-screen-name-1") - user2 := sm.AddSession("sess-id-2", "user-screen-name-2") - user3 := sm.AddSession("sess-id-3", "user-screen-name-3") + user1 := sm.AddSession("user-screen-name-1") + user2 := sm.AddSession("user-screen-name-2") + user3 := sm.AddSession("user-screen-name-3") want := wire.SNACMessage{Frame: wire.SNACFrame{FoodGroup: wire.ICBM}} diff --git a/state/session_test.go b/state/session_test.go index 539644de..b64adf03 100644 --- a/state/session_test.go +++ b/state/session_test.go @@ -19,15 +19,6 @@ func TestSession_SetAndGetAwayMessage(t *testing.T) { assert.Equal(t, msg, s.AwayMessage()) } -func TestSession_SetAndGetID(t *testing.T) { - s := NewSession() - // make sure NewSession creates a default ID - assert.NotEmpty(t, s.SetID) - newID := "new-id" - s.SetID(newID) - assert.Equal(t, newID, s.ID()) -} - func TestSession_IncrementAndGetWarning(t *testing.T) { s := NewSession() assert.Zero(t, s.Warning())