Skip to content

Commit

Permalink
chore: add session handler
Browse files Browse the repository at this point in the history
  • Loading branch information
katallaxie committed Feb 9, 2024
1 parent 95f1243 commit 2e2b3a8
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 16 deletions.
9 changes: 8 additions & 1 deletion adapters/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ type Session struct {
ExpiresAt time.Time `json:"expires_at"`
SessionToken string `json:"session_token"`
UserID uuid.UUID `json:"user_id"`
User User `json:"user" gorm:"foreignKey:UserID;constraint:OnDelete:CASCADE"`
User User `json:"user"`

CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
Expand Down Expand Up @@ -122,6 +122,8 @@ type Adapter interface {
GetSession(ctx context.Context, sessionToken string) (Session, error)
// UpdateSession updates a session.
UpdateSession(ctx context.Context, session Session) (Session, error)
// RefreshSession refreshes a session.
RefreshSession(ctx context.Context, session Session) (Session, error)
// DeleteSession deletes a session by session token.
DeleteSession(ctx context.Context, sessionToken string) error
// CreateVerificationToken creates a new verification token.
Expand Down Expand Up @@ -190,6 +192,11 @@ func (a *UnimplementedAdapter) UpdateSession(_ context.Context, session Session)
return Session{}, ErrUnimplemented
}

// RefreshSession refreshes a session.
func (a *UnimplementedAdapter) RefreshSession(_ context.Context, session Session) (Session, error) {
return Session{}, ErrUnimplemented
}

// DeleteSession deletes a session by session token.
func (a *UnimplementedAdapter) DeleteSession(_ context.Context, sessionToken string) error {
return ErrUnimplemented
Expand Down
12 changes: 11 additions & 1 deletion adapters/gorm/gorm.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func (a *gormAdapter) CreateUser(ctx context.Context, user adapters.User) (adapt
// GetSession ...
func (a *gormAdapter) GetSession(ctx context.Context, sessionToken string) (adapters.Session, error) {
var session adapters.Session
err := a.db.WithContext(ctx).Where("session_token = ?", sessionToken).First(&session).Error
err := a.db.WithContext(ctx).Preload("User").Where("session_token = ?", sessionToken).First(&session).Error
if err != nil {
return adapters.Session{}, err
}
Expand Down Expand Up @@ -91,6 +91,16 @@ func (a *gormAdapter) DeleteSession(ctx context.Context, sessionToken string) er
return a.db.WithContext(ctx).Where("session_token = ?", sessionToken).Delete(&adapters.Session{}).Error
}

// RefreshSession ...
func (a *gormAdapter) RefreshSession(ctx context.Context, session adapters.Session) (adapters.Session, error) {
err := a.db.WithContext(ctx).Model(&adapters.Session{}).Where("session_token = ?", session.SessionToken).Updates(&session).Error
if err != nil {
return adapters.Session{}, err
}

return session, nil
}

// DeleteUser ...
func (a *gormAdapter) DeleteUser(ctx context.Context, id uuid.UUID) error {
return a.db.WithContext(ctx).Where("id = ?", id).Delete(&adapters.User{}).Error
Expand Down
3 changes: 1 addition & 2 deletions examples/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,12 +171,11 @@ func run(ctx context.Context) error {
}

gothConfig := goth.Config{Adapter: ga, Secret: goth.GenerateKey()}
app.Use(goth.NewSessionHandler(gothConfig))

app.Get("/login", func(c *fiber.Ctx) error {
c.Set(fiber.HeaderContentType, fiber.MIMETextHTML)
return t.Execute(c.Response().BodyWriter(), providerIndex)
})
app.Get("/session", goth.NewSessionHandler(gothConfig))
app.Get("/login/:provider", goth.NewBeginAuthHandler(gothConfig))
app.Get("/auth/:provider/callback", goth.NewCompleteAuthHandler(gothConfig))
app.Get("/logout", goth.NewLogoutHandler(gothConfig))
Expand Down
46 changes: 34 additions & 12 deletions goth.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,21 +71,43 @@ func (SessionHandler) New(cfg Config) fiber.Handler {
return c.Next()
}

// cookie := c.Cookies(cfg.CookieName)
// if cookie == "" {
// return cfg.IndexHandler(c)
// }
cookie := c.Cookies(cfg.CookieName)
if cookie == "" {
return cfg.ErrorHandler(c, ErrMissingCookie)
}

// session, err := cfg.Adapter.GetSession(c.Context(), cookie)
// if err != nil {
// return cfg.ErrorHandler(c, err)
// }
session, err := cfg.Adapter.GetSession(c.Context(), cookie)
if err != nil {
return cfg.ErrorHandler(c, err)
}

// if !session.IsValid() {
// return cfg.IndexHandler(c)
// }
if !session.IsValid() {
cfg.ErrorHandler(c, err)
}

return c.Next()
duration, err := time.ParseDuration(cfg.Expiry)
if err != nil {
return cfg.ErrorHandler(c, err)
}
expires := time.Now().Add(duration)
session.ExpiresAt = expires

session, err = cfg.Adapter.RefreshSession(c.Context(), session)
if err != nil {
return cfg.ErrorHandler(c, err)
}

cookieValue := fasthttp.Cookie{}
cookieValue.SetKeyBytes([]byte(cfg.CookieName))
cookieValue.SetValueBytes([]byte(session.SessionToken))
cookieValue.SetHTTPOnly(true)
cookieValue.SetSameSite(fasthttp.CookieSameSiteLaxMode)
cookieValue.SetExpire(expires)
cookieValue.SetPath("/")

c.Response().Header.SetCookie(&cookieValue)

return c.JSON(session)
}
}

Expand Down

0 comments on commit 2e2b3a8

Please sign in to comment.