From a04d408a0c445096bcf0b9c19320f351d955e52a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=80=97=E5=AD=90?= Date: Thu, 5 Sep 2024 00:55:09 +0800 Subject: [PATCH] feat: optimize session --- contracts/session/manager.go | 2 +- session/manager.go | 90 ++++++++++++++++++------ session/middleware/start_session.go | 6 +- session/middleware/start_session_test.go | 1 + session/service_provider.go | 31 -------- session/session.go | 8 +++ 6 files changed, 83 insertions(+), 55 deletions(-) diff --git a/contracts/session/manager.go b/contracts/session/manager.go index 4db8935ac..16b236010 100644 --- a/contracts/session/manager.go +++ b/contracts/session/manager.go @@ -6,5 +6,5 @@ type Manager interface { // Driver retrieves the session driver by name. Driver(name ...string) (Driver, error) // Extend extends the session manager with a custom driver. - Extend(driver string, handler func() Driver) Manager + Extend(driver string, handler Driver) (Manager, error) } diff --git a/session/manager.go b/session/manager.go index ed3af0854..f2b5facbe 100644 --- a/session/manager.go +++ b/session/manager.go @@ -2,33 +2,51 @@ package session import ( "fmt" + "sync" + "time" "github.com/goravel/framework/contracts/config" "github.com/goravel/framework/contracts/foundation" sessioncontract "github.com/goravel/framework/contracts/session" "github.com/goravel/framework/session/driver" + "github.com/goravel/framework/support/color" ) type Manager struct { - config config.Config - customDrivers map[string]sessioncontract.Driver - drivers map[string]sessioncontract.Driver - json foundation.Json + config config.Config + drivers map[string]sessioncontract.Driver + json foundation.Json + sessionPool sync.Pool } func NewManager(config config.Config, json foundation.Json) *Manager { manager := &Manager{ - config: config, - customDrivers: make(map[string]sessioncontract.Driver), - drivers: make(map[string]sessioncontract.Driver), - json: json, + config: config, + drivers: make(map[string]sessioncontract.Driver), + json: json, + sessionPool: sync.Pool{New: func() any { + return &Session{ + attributes: make(map[string]any), + } + }, + }, } - manager.registerDrivers() + manager.createDefaultDriver() return manager } func (m *Manager) BuildSession(handler sessioncontract.Driver, sessionID ...string) sessioncontract.Session { - return NewSession(m.config.GetString("session.cookie"), handler, m.json, sessionID...) + session := m.AcquireSession() + session.name = m.config.GetString("session.cookie") + session.driver = handler + session.json = m.json + if len(sessionID) > 0 { + session.SetID(sessionID[0]) + } else { + session.SetID("") + } + + return session } func (m *Manager) Driver(name ...string) (sessioncontract.Driver, error) { @@ -44,30 +62,58 @@ func (m *Manager) Driver(name ...string) (sessioncontract.Driver, error) { } if m.drivers[driverName] == nil { - if m.customDrivers[driverName] == nil { - return nil, fmt.Errorf("driver [%s] not supported", driverName) - } - - m.drivers[driverName] = m.customDrivers[driverName] + return nil, fmt.Errorf("driver [%s] not supported", driverName) } return m.drivers[driverName], nil } -func (m *Manager) Extend(driver string, handler func() sessioncontract.Driver) sessioncontract.Manager { - m.customDrivers[driver] = handler() - return m +func (m *Manager) Extend(driver string, handler sessioncontract.Driver) (sessioncontract.Manager, error) { + if m.drivers[driver] != nil { + return nil, fmt.Errorf("driver [%s] already exists", driver) + } + m.drivers[driver] = handler + m.startGcTimer(m.drivers[driver]) + return m, nil +} + +func (m *Manager) AcquireSession() *Session { + session := m.sessionPool.Get().(*Session) + return session +} + +func (m *Manager) ReleaseSession(session *Session) { + session.reset() + m.sessionPool.Put(session) } func (m *Manager) getDefaultDriver() string { return m.config.GetString("session.driver") } -func (m *Manager) createFileDriver() sessioncontract.Driver { +func (m *Manager) createDefaultDriver() { lifetime := m.config.GetInt("session.lifetime") - return driver.NewFile(m.config.GetString("session.files"), lifetime) + if _, err := m.Extend("file", driver.NewFile(m.config.GetString("session.files"), lifetime)); err != nil { + panic(fmt.Sprintf("failed to extend session manager: %v", err)) + } } -func (m *Manager) registerDrivers() { - m.drivers["file"] = m.createFileDriver() +// startGcTimer starts a garbage collection timer for the session driver. +func (m *Manager) startGcTimer(driver sessioncontract.Driver) { + interval := ConfigFacade.GetInt("session.gc_interval", 30) + if interval <= 0 { + // No need to start the timer if the interval is zero or negative + return + } + + ticker := time.NewTicker(time.Duration(interval) * time.Minute) + + go func() { + for range ticker.C { + lifetime := ConfigFacade.GetInt("session.lifetime") * 60 + if err := driver.Gc(lifetime); err != nil { + color.Red().Printf("Error performing garbage collection: %s\n", err) + } + } + }() } diff --git a/session/middleware/start_session.go b/session/middleware/start_session.go index 35c86ea3a..41d867a5a 100644 --- a/session/middleware/start_session.go +++ b/session/middleware/start_session.go @@ -50,8 +50,12 @@ func StartSession() http.Middleware { req.Next() // Save session - if err := s.Save(); err != nil { + if err = s.Save(); err != nil { color.Red().Printf("Error saving session: %s\n", err) } + + // Release session + // TODO - any better way to release the session? + session.SessionFacade.(*session.Manager).ReleaseSession(s.(*session.Session)) } } diff --git a/session/middleware/start_session_test.go b/session/middleware/start_session_test.go index 87a4dba5e..6d4c6fdb7 100644 --- a/session/middleware/start_session_test.go +++ b/session/middleware/start_session_test.go @@ -43,6 +43,7 @@ func TestStartSession(t *testing.T) { mockConfig := &configmocks.Config{} session.ConfigFacade = mockConfig mockConfig.On("GetInt", "session.lifetime").Return(120).Once() + mockConfig.On("GetInt", "session.gc_interval", 30).Return(30).Once() mockConfig.On("GetString", "session.files").Return("storage/framework/sessions").Once() session.SessionFacade = session.NewManager(mockConfig, json.NewJson()) server := httptest.NewServer(testHttpSessionMiddleware(nethttp.HandlerFunc(func(w nethttp.ResponseWriter, r *nethttp.Request) { diff --git a/session/service_provider.go b/session/service_provider.go index 1e4629787..a3c3d2c44 100644 --- a/session/service_provider.go +++ b/session/service_provider.go @@ -1,12 +1,9 @@ package session import ( - "time" - "github.com/goravel/framework/contracts/config" "github.com/goravel/framework/contracts/foundation" "github.com/goravel/framework/contracts/session" - "github.com/goravel/framework/support/color" ) var ( @@ -30,32 +27,4 @@ func (receiver *ServiceProvider) Register(app foundation.Application) { func (receiver *ServiceProvider) Boot(app foundation.Application) { SessionFacade = app.MakeSession() ConfigFacade = app.MakeConfig() - - driver, err := SessionFacade.Driver() - if err != nil { - color.Red().Println(err) - return - } - - startGcTimer(driver) -} - -// startGcTimer starts a garbage collection timer for the session driver. -func startGcTimer(driver session.Driver) { - interval := ConfigFacade.GetInt("session.gc_interval", 30) - if interval <= 0 { - // No need to start the timer if the interval is zero or negative - return - } - - ticker := time.NewTicker(time.Duration(interval) * time.Minute) - - go func() { - for range ticker.C { - lifetime := ConfigFacade.GetInt("session.lifetime") * 60 - if err := driver.Gc(lifetime); err != nil { - color.Red().Printf("Error performing garbage collection: %s\n", err) - } - } - }() } diff --git a/session/session.go b/session/session.go index decd16c98..c65304d87 100644 --- a/session/session.go +++ b/session/session.go @@ -272,6 +272,14 @@ func (s *Session) removeFromOldFlashData(keys ...string) { s.Put("_flash.old", old) } +func (s *Session) reset() { + s.id = "" + s.name = "" + s.attributes = make(map[string]any) + s.driver = nil + s.started = false +} + // toStringSlice converts an interface slice to a string slice. func toStringSlice(anySlice []any) []string { strSlice := make([]string, len(anySlice))