diff --git a/contracts/session/manager.go b/contracts/session/manager.go index bf2dd9bfb..ff0a108cf 100644 --- a/contracts/session/manager.go +++ b/contracts/session/manager.go @@ -2,7 +2,7 @@ package session type Manager interface { // BuildSession constructs a new session with the given handler and session ID. - BuildSession(handler Driver, sessionID ...string) Session + BuildSession(handler Driver, sessionID ...string) (Session, error) // Driver retrieves the session driver by name. Driver(name ...string) (Driver, error) // Extend extends the session manager with a custom driver. diff --git a/contracts/session/session.go b/contracts/session/session.go index f4273f9eb..49edcaf95 100644 --- a/contracts/session/session.go +++ b/contracts/session/session.go @@ -42,6 +42,8 @@ type Session interface { Remove(key string) any // Save saves the session. Save() error + // SetDriver sets the session driver + SetDriver(driver Driver) Session // SetID sets the ID of the session. SetID(id string) Session // SetName sets the name of the session. diff --git a/mocks/session/Manager.go b/mocks/session/Manager.go index 9215db392..a4bab575a 100644 --- a/mocks/session/Manager.go +++ b/mocks/session/Manager.go @@ -21,7 +21,7 @@ func (_m *Manager) EXPECT() *Manager_Expecter { } // BuildSession provides a mock function with given fields: handler, sessionID -func (_m *Manager) BuildSession(handler session.Driver, sessionID ...string) session.Session { +func (_m *Manager) BuildSession(handler session.Driver, sessionID ...string) (session.Session, error) { _va := make([]interface{}, len(sessionID)) for _i := range sessionID { _va[_i] = sessionID[_i] @@ -36,6 +36,10 @@ func (_m *Manager) BuildSession(handler session.Driver, sessionID ...string) ses } var r0 session.Session + var r1 error + if rf, ok := ret.Get(0).(func(session.Driver, ...string) (session.Session, error)); ok { + return rf(handler, sessionID...) + } if rf, ok := ret.Get(0).(func(session.Driver, ...string) session.Session); ok { r0 = rf(handler, sessionID...) } else { @@ -44,7 +48,13 @@ func (_m *Manager) BuildSession(handler session.Driver, sessionID ...string) ses } } - return r0 + if rf, ok := ret.Get(1).(func(session.Driver, ...string) error); ok { + r1 = rf(handler, sessionID...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 } // Manager_BuildSession_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'BuildSession' @@ -73,12 +83,12 @@ func (_c *Manager_BuildSession_Call) Run(run func(handler session.Driver, sessio return _c } -func (_c *Manager_BuildSession_Call) Return(_a0 session.Session) *Manager_BuildSession_Call { - _c.Call.Return(_a0) +func (_c *Manager_BuildSession_Call) Return(_a0 session.Session, _a1 error) *Manager_BuildSession_Call { + _c.Call.Return(_a0, _a1) return _c } -func (_c *Manager_BuildSession_Call) RunAndReturn(run func(session.Driver, ...string) session.Session) *Manager_BuildSession_Call { +func (_c *Manager_BuildSession_Call) RunAndReturn(run func(session.Driver, ...string) (session.Session, error)) *Manager_BuildSession_Call { _c.Call.Return(run) return _c } diff --git a/mocks/session/Session.go b/mocks/session/Session.go index b41c379e1..b5fbc0975 100644 --- a/mocks/session/Session.go +++ b/mocks/session/Session.go @@ -1021,6 +1021,54 @@ func (_c *Session_Save_Call) RunAndReturn(run func() error) *Session_Save_Call { return _c } +// SetDriver provides a mock function with given fields: driver +func (_m *Session) SetDriver(driver session.Driver) session.Session { + ret := _m.Called(driver) + + if len(ret) == 0 { + panic("no return value specified for SetDriver") + } + + var r0 session.Session + if rf, ok := ret.Get(0).(func(session.Driver) session.Session); ok { + r0 = rf(driver) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(session.Session) + } + } + + return r0 +} + +// Session_SetDriver_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetDriver' +type Session_SetDriver_Call struct { + *mock.Call +} + +// SetDriver is a helper method to define mock.On call +// - driver session.Driver +func (_e *Session_Expecter) SetDriver(driver interface{}) *Session_SetDriver_Call { + return &Session_SetDriver_Call{Call: _e.mock.On("SetDriver", driver)} +} + +func (_c *Session_SetDriver_Call) Run(run func(driver session.Driver)) *Session_SetDriver_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(session.Driver)) + }) + return _c +} + +func (_c *Session_SetDriver_Call) Return(_a0 session.Session) *Session_SetDriver_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *Session_SetDriver_Call) RunAndReturn(run func(session.Driver) session.Session) *Session_SetDriver_Call { + _c.Call.Return(run) + return _c +} + // SetID provides a mock function with given fields: id func (_m *Session) SetID(id string) session.Session { ret := _m.Called(id) diff --git a/session/errors.go b/session/errors.go new file mode 100644 index 000000000..a8904053d --- /dev/null +++ b/session/errors.go @@ -0,0 +1,7 @@ +package session + +import "errors" + +var ( + ErrDriverNotSet = errors.New("session driver is not set") +) diff --git a/session/manager.go b/session/manager.go index fba6436ef..6355f726a 100644 --- a/session/manager.go +++ b/session/manager.go @@ -25,9 +25,7 @@ func NewManager(config config.Config, json foundation.Json) *Manager { drivers: make(map[string]sessioncontract.Driver), json: json, sessionPool: sync.Pool{New: func() any { - return &Session{ - attributes: make(map[string]any), - } + return NewSession("", nil, json) }, }, } @@ -35,26 +33,22 @@ func NewManager(config config.Config, json foundation.Json) *Manager { return manager } -func (m *Manager) AcquireSession() *Session { - session := m.sessionPool.Get().(*Session) - return session -} - -func (m *Manager) BuildSession(handler sessioncontract.Driver, sessionID ...string) sessioncontract.Session { +func (m *Manager) BuildSession(handler sessioncontract.Driver, sessionID ...string) (sessioncontract.Session, error) { if handler == nil { - panic("session driver cannot be nil") + return nil, ErrDriverNotSet } - session := m.AcquireSession() - session.setDriver(handler) - session.setJson(m.json) - session.SetName(m.config.GetString("session.cookie")) + + session := m.acquireSession() + session.SetDriver(handler). + SetName(m.config.GetString("session.cookie")) + if len(sessionID) > 0 { session.SetID(sessionID[0]) } else { session.SetID("") } - return session + return session, nil } func (m *Manager) Driver(name ...string) (sessioncontract.Driver, error) { @@ -86,9 +80,16 @@ func (m *Manager) Extend(driver string, handler func() sessioncontract.Driver) e } func (m *Manager) ReleaseSession(session sessioncontract.Session) { - s := session.(*Session) - s.reset() - m.sessionPool.Put(s) + session.Flush(). + SetDriver(nil). + SetName(""). + SetID("") + m.sessionPool.Put(session) +} + +func (m *Manager) acquireSession() sessioncontract.Session { + session := m.sessionPool.Get().(sessioncontract.Session) + return session } func (m *Manager) getDefaultDriver() string { diff --git a/session/manager_test.go b/session/manager_test.go index a45a5ffc1..821fffb88 100644 --- a/session/manager_test.go +++ b/session/manager_test.go @@ -85,6 +85,10 @@ func (s *ManagerTestSuite) TestExtend() { s.Nil(err) s.NotNil(driver) s.Equal("*session.CustomDriver", fmt.Sprintf("%T", driver)) + + // driver already exists + err = s.manager.Extend("test", NewCustomDriver) + s.Errorf(err, "driver [%s] already exists", "test") } func (s *ManagerTestSuite) TestBuildSession() { @@ -94,9 +98,22 @@ func (s *ManagerTestSuite) TestBuildSession() { s.Equal("*driver.File", fmt.Sprintf("%T", driver)) s.mockConfig.On("GetString", "session.cookie").Return("test_cookie").Once() - session := s.manager.BuildSession(driver) + session, err := s.manager.BuildSession(driver) + session.Put("name", "goravel") + + s.Nil(err) s.NotNil(session) s.Equal("test_cookie", session.GetName()) + s.Equal("goravel", session.Get("name")) + + s.manager.ReleaseSession(session) + s.Empty(session.GetName()) + s.Empty(session.All()) + + // driver is nil + session, err = s.manager.BuildSession(nil) + s.ErrorIs(err, ErrDriverNotSet) + s.Nil(session) } func (s *ManagerTestSuite) TestGetDefaultDriver() { diff --git a/session/middleware/start_session.go b/session/middleware/start_session.go index 9df9d46a0..d0f9b19a8 100644 --- a/session/middleware/start_session.go +++ b/session/middleware/start_session.go @@ -26,7 +26,13 @@ func StartSession() http.Middleware { } // Build session - s := session.SessionFacade.BuildSession(driver) + s, err := session.SessionFacade.BuildSession(driver) + if err != nil { + color.Red().Println(err) + req.Next() + return + } + s.SetID(req.Cookie(s.GetName())) // Start session diff --git a/session/session.go b/session/session.go index 7c796e09f..4e9fc0e4e 100644 --- a/session/session.go +++ b/session/session.go @@ -8,6 +8,7 @@ import ( "github.com/goravel/framework/contracts/foundation" sessioncontract "github.com/goravel/framework/contracts/session" + "github.com/goravel/framework/support/color" supportmaps "github.com/goravel/framework/support/maps" "github.com/goravel/framework/support/str" ) @@ -155,6 +156,10 @@ func (s *Session) Save() error { return err } + if err = s.validateDriver(); err != nil { + return err + } + if err = s.driver.Write(s.GetID(), string(data)); err != nil { return err } @@ -164,6 +169,11 @@ func (s *Session) Save() error { return nil } +func (s *Session) SetDriver(driver sessioncontract.Driver) sessioncontract.Session { + s.driver = driver + return s +} + func (s *Session) SetID(id string) sessioncontract.Session { if s.isValidID(id) { s.id = id @@ -210,6 +220,13 @@ func (s *Session) loadSession() { } } +func (s *Session) validateDriver() error { + if s.driver == nil { + return ErrDriverNotSet + } + return nil +} + func (s *Session) migrate(destroy ...bool) error { shouldDestroy := false if len(destroy) > 0 { @@ -217,8 +234,11 @@ func (s *Session) migrate(destroy ...bool) error { } if shouldDestroy { - err := s.driver.Destroy(s.GetID()) - if err != nil { + if err := s.validateDriver(); err != nil { + return err + } + + if err := s.driver.Destroy(s.GetID()); err != nil { return err } } @@ -229,12 +249,19 @@ func (s *Session) migrate(destroy ...bool) error { } func (s *Session) readFromHandler() map[string]any { + if err := s.validateDriver(); err != nil { + color.Red().Println(err) + return nil + } + value, err := s.driver.Read(s.GetID()) if err != nil { + color.Red().Println(err) return nil } var data map[string]any if err := s.json.Unmarshal([]byte(value), &data); err != nil { + color.Red().Println(err) return nil } return data @@ -272,22 +299,6 @@ func (s *Session) removeFromOldFlashData(keys ...string) { s.Put("_flash.old", old) } -func (s *Session) reset() { - s.id = "" - s.name = "" - s.attributes = make(map[string]any) - s.driver = nil - s.started = false -} - -func (s *Session) setDriver(driver sessioncontract.Driver) { - s.driver = driver -} - -func (s *Session) setJson(json foundation.Json) { - s.json = json -} - // toStringSlice converts an interface slice to a string slice. func toStringSlice(anySlice []any) []string { strSlice := make([]string, len(anySlice)) diff --git a/session/session_test.go b/session/session_test.go index c09a704dd..e13420e6c 100644 --- a/session/session_test.go +++ b/session/session_test.go @@ -162,6 +162,10 @@ func (s *SessionTestSuite) TestMigrate() { s.driver.On("Destroy", oldID).Return(nil).Once() s.Nil(s.session.migrate(true)) s.NotEqual(oldID, s.session.GetID()) + + // when driver is nil + s.session.SetDriver(nil) + s.ErrorIs(s.session.migrate(true), ErrDriverNotSet) } func (s *SessionTestSuite) TestMissing() { @@ -297,6 +301,10 @@ func (s *SessionTestSuite) TestSave() { s.Equal(errors.New("error"), s.session.Save()) s.True(s.session.started) + + // when driver is nil + s.session.SetDriver(nil) + s.ErrorIs(s.session.Save(), ErrDriverNotSet) } func (s *SessionTestSuite) TestSetID() {