Skip to content

Commit

Permalink
chore: add session middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
katallaxie committed Feb 9, 2024
1 parent 2daaa54 commit 8d5089c
Show file tree
Hide file tree
Showing 6 changed files with 107 additions and 52 deletions.
58 changes: 29 additions & 29 deletions adapters/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,11 @@ type Session struct {
DeletedAt *time.Time
}

// IsValid returns true if the session is valid.
func (s Session) IsValid() bool {
return s.ExpiresAt.After(time.Now())
}

// VerificationToken ...
type VerificationToken struct {
Token string `json:"token" gorm:"primaryKey"`
Expand All @@ -100,29 +105,29 @@ type Adapter interface {
// CreateUser creates a new user.
CreateUser(ctx context.Context, user User) (User, error)
// GetUser retrieves a user by ID.
GetUser(id uuid.UUID) (User, error)
GetUser(ctx context.Context, id uuid.UUID) (User, error)
// GetUserByEmail retrieves a user by email.
GetUserByEmail(email string) (User, error)
GetUserByEmail(ctx context.Context, email string) (User, error)
// UpdateUser updates a user.
UpdateUser(user User) (User, error)
UpdateUser(ctx context.Context, user User) (User, error)
// DeleteUser deletes a user by ID.
DeleteUser(id uuid.UUID) error
DeleteUser(ctx context.Context, id uuid.UUID) error
// LinkAccount links an account to a user.
LinkAccount(accountID, userID uuid.UUID) error
LinkAccount(ctx context.Context, accountID, userID uuid.UUID) error
// UnlinkAccount unlinks an account from a user.
UnlinkAccount(accountID, userID uuid.UUID) error
UnlinkAccount(ctx context.Context, accountID, userID uuid.UUID) error
// CreateSession creates a new session.
CreateSession(ctx context.Context, userID uuid.UUID, expires time.Time) (Session, error)
// GetSession retrieves a session by session token.
GetSession(sessionToken string) (Session, error)
GetSession(ctx context.Context, sessionToken string) (Session, error)
// UpdateSession updates a session.
UpdateSession(session Session) (Session, error)
UpdateSession(ctx context.Context, session Session) (Session, error)
// DeleteSession deletes a session by session token.
DeleteSession(sessionToken string) error
DeleteSession(ctx context.Context, sessionToken string) error
// CreateVerificationToken creates a new verification token.
CreateVerificationToken(verficationToken VerificationToken) (VerificationToken, error)
CreateVerificationToken(ctx context.Context, verficationToken VerificationToken) (VerificationToken, error)
// UseVerficationToken uses a verification token.
UseVerficationToken(identifier string, token string) (VerificationToken, error)
UseVerficationToken(ctx context.Context, identifier string, token string) (VerificationToken, error)
}

var _ Adapter = (*UnimplementedAdapter)(nil)
Expand All @@ -136,75 +141,70 @@ func (a *UnimplementedAdapter) CreateUser(_ context.Context, user User) (User, e
}

// GetUser retrieves a user by ID.
func (a *UnimplementedAdapter) GetUser(id uuid.UUID) (User, error) {
func (a *UnimplementedAdapter) GetUser(_ context.Context, id uuid.UUID) (User, error) {
return User{}, ErrUnimplemented
}

// GetUserByEmail retrieves a user by email.
func (a *UnimplementedAdapter) GetUserByEmail(email string) (User, error) {
func (a *UnimplementedAdapter) GetUserByEmail(_ context.Context, email string) (User, error) {
return User{}, ErrUnimplemented
}

// GetUserByAccount retrieves a user by account.
func (a *UnimplementedAdapter) GetUserByAccount(provider string, providerAccountID string) (User, error) {
func (a *UnimplementedAdapter) GetUserByAccount(_ context.Context, provider string, providerAccountID string) (User, error) {
return User{}, ErrUnimplemented
}

// UpdateUser updates a user.
func (a *UnimplementedAdapter) UpdateUser(user User) (User, error) {
func (a *UnimplementedAdapter) UpdateUser(_ context.Context, user User) (User, error) {
return User{}, ErrUnimplemented
}

// DeleteUser deletes a user by ID.
func (a *UnimplementedAdapter) DeleteUser(id uuid.UUID) error {
func (a *UnimplementedAdapter) DeleteUser(_ context.Context, id uuid.UUID) error {
return ErrUnimplemented
}

// LinkAccount links an account to a user.
func (a *UnimplementedAdapter) LinkAccount(accountID, userID uuid.UUID) error {
func (a *UnimplementedAdapter) LinkAccount(_ context.Context, accountID, userID uuid.UUID) error {
return ErrUnimplemented
}

// UnlinkAccount unlinks an account from a user.
func (a *UnimplementedAdapter) UnlinkAccount(accountID, userID uuid.UUID) error {
func (a *UnimplementedAdapter) UnlinkAccount(_ context.Context, accountID, userID uuid.UUID) error {
return ErrUnimplemented
}

// CreateSession creates a new session.
func (a *UnimplementedAdapter) CreateSession(ctx context.Context, userID uuid.UUID, expires time.Time) (Session, error) {
func (a *UnimplementedAdapter) CreateSession(_ context.Context, userID uuid.UUID, expires time.Time) (Session, error) {
return Session{}, ErrUnimplemented
}

// GetSession retrieves a session by session token.
func (a *UnimplementedAdapter) GetSession(sessionToken string) (Session, error) {
func (a *UnimplementedAdapter) GetSession(_ context.Context, sessionToken string) (Session, error) {
return Session{}, ErrUnimplemented
}

// UpdateSession updates a session.
func (a *UnimplementedAdapter) UpdateSession(session Session) (Session, error) {
func (a *UnimplementedAdapter) UpdateSession(_ context.Context, session Session) (Session, error) {
return Session{}, ErrUnimplemented
}

// DeleteSession deletes a session by session token.
func (a *UnimplementedAdapter) DeleteSession(sessionToken string) error {
func (a *UnimplementedAdapter) DeleteSession(_ context.Context, sessionToken string) error {
return ErrUnimplemented
}

// CreateVerificationToken creates a new verification token.
func (a *UnimplementedAdapter) CreateVerificationToken(verficationToken VerificationToken) (VerificationToken, error) {
func (a *UnimplementedAdapter) CreateVerificationToken(_ context.Context, erficationToken VerificationToken) (VerificationToken, error) {
return VerificationToken{}, ErrUnimplemented
}

// UseVerficationToken uses a verification token.
func (a *UnimplementedAdapter) UseVerficationToken(identifier string, token string) (VerificationToken, error) {
func (a *UnimplementedAdapter) UseVerficationToken(_ context.Context, identifier string, token string) (VerificationToken, error) {
return VerificationToken{}, ErrUnimplemented
}

// GetAccount retrieve by provider and provider account ID.
func (a *UnimplementedAdapter) GetAccount(provider string, providerAccountID string) (Account, error) {
return Account{}, ErrUnimplemented
}

// StringPtr returns a pointer to the string value passed in.
func StringPtr(s string) *string {
return &s
Expand Down
29 changes: 20 additions & 9 deletions adapters/gorm/gorm.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,21 @@ func (a *gormAdapter) CreateUser(ctx context.Context, user adapters.User) (adapt
return user, nil
}

// GetSession ...
func (a *gormAdapter) GetSession(ctx context.Context, sessionToken string) (adapters.Session, error) {
var session adapters.Session
err := a.db.WithContext(ctx).Where("session_token = ?", sessionToken).First(&session).Error
if err != nil {
return adapters.Session{}, err
}

return session, nil
}

// GetUser ...
func (a *gormAdapter) GetUser(id uuid.UUID) (adapters.User, error) {
func (a *gormAdapter) GetUser(ctx context.Context, id uuid.UUID) (adapters.User, error) {
var user adapters.User
err := a.db.Preload("Accounts").Where("id = ?", id).First(&user).Error
err := a.db.WithContext(ctx).Preload("Accounts").Where("id = ?", id).First(&user).Error
if err != nil {
return adapters.User{}, err
}
Expand All @@ -66,7 +77,7 @@ func (a *gormAdapter) GetUser(id uuid.UUID) (adapters.User, error) {

// CreateSession ...
func (a *gormAdapter) CreateSession(ctx context.Context, userID uuid.UUID, expires time.Time) (adapters.Session, error) {
session := adapters.Session{UserID: userID, SessionToken: uuid.NewString()}
session := adapters.Session{UserID: userID, SessionToken: uuid.NewString(), ExpiresAt: expires}
err := a.db.WithContext(ctx).Create(&session).Error
if err != nil {
return adapters.Session{}, err
Expand All @@ -76,16 +87,16 @@ func (a *gormAdapter) CreateSession(ctx context.Context, userID uuid.UUID, expir
}

// DeleteUser ...
func (a *gormAdapter) DeleteUser(id uuid.UUID) error {
return a.db.Where("id = ?", id).Delete(&adapters.User{}).Error
func (a *gormAdapter) DeleteUser(ctx context.Context, id uuid.UUID) error {
return a.db.WithContext(ctx).Where("id = ?", id).Delete(&adapters.User{}).Error
}

// LinkAccount ...
func (a *gormAdapter) LinkAccount(accountID, userID uuid.UUID) error {
return a.db.Model(&adapters.Account{}).Where("id = ?", accountID).Update("user_id", userID).Error
func (a *gormAdapter) LinkAccount(ctx context.Context, accountID, userID uuid.UUID) error {
return a.db.WithContext(ctx).Model(&adapters.Account{}).Where("id = ?", accountID).Update("user_id", userID).Error
}

// DeleteSession ...
func (a *gormAdapter) DeleteSession(sessionToken string) error {
return a.db.Where("session_token = ?", sessionToken).Delete(&adapters.Session{}).Error
func (a *gormAdapter) DeleteSession(ctx context.Context, sessionToken string) error {
return a.db.WithContext(ctx).Where("session_token = ?", sessionToken).Delete(&adapters.Session{}).Error
}
3 changes: 2 additions & 1 deletion examples/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func run(ctx context.Context) error {
return err
}

providers.RegisterProvider(github.New(ga, os.Getenv("GITHUB_KEY"), os.Getenv("GITHUB_SECRET"), "http://localhost:3000/auth/github/callback"))
providers.RegisterProvider(github.New(os.Getenv("GITHUB_KEY"), os.Getenv("GITHUB_SECRET"), "http://localhost:3000/auth/github/callback"))

m := map[string]string{
"amazon": "Amazon",
Expand Down Expand Up @@ -171,6 +171,7 @@ func run(ctx context.Context) error {
}

gothConfig := goth.Config{Adapter: ga, Secret: goth.GenerateKey()}
app.Use(goth.NewSessionHandler(gothConfig))

app.Get("/login/:provider", goth.NewBeginAuthHandler(gothConfig))
app.Get("/auth/:provider/callback", goth.NewCompleteAuthHandler(gothConfig))
Expand Down
49 changes: 47 additions & 2 deletions goth.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,41 @@ func ProviderFromContext(c *fiber.Ctx) string {
return c.Get(fmt.Sprint(providerKey))
}

// SessionHandler is the default handler for the session.
type SessionHandler struct{}

// New creates a new handler to manage the session.
func (SessionHandler) New(cfg Config) fiber.Handler {
return func(c *fiber.Ctx) error {
if cfg.Next != nil && cfg.Next(c) {
return c.Next()
}

cookie := c.Cookies(cfg.CookieName)
if cookie == "" {
return cfg.ErrorHandler(c, ErrMissingSession)
}

session, err := cfg.Adapter.GetSession(c.Context(), cookie)
if err != nil {
return cfg.ErrorHandler(c, err)
}

if !session.IsValid() {
return fiber.ErrForbidden
}

return c.Next()
}
}

// NewSessionHandler returns a new default session handler.
func NewSessionHandler(config ...Config) fiber.Handler {
cfg := configDefault(config...)

return cfg.SessionHandler.New(cfg)
}

// BeginAuthHandler is the default handler to begin the authentication process.
type BeginAuthHandler struct{}

Expand All @@ -86,7 +121,7 @@ func (BeginAuthHandler) New(cfg Config) fiber.Handler {
return err
}

intent, err := provider.BeginAuth(state)
intent, err := provider.BeginAuth(c.Context(), cfg.Adapter, state)
if err != nil {
return err
}
Expand Down Expand Up @@ -134,7 +169,7 @@ func (CompleteAuthCompleteHandler) New(cfg Config) fiber.Handler {
return cfg.ErrorHandler(c, err)
}

user, err := provider.CompleteAuth(c.Context(), &Params{ctx: c})
user, err := provider.CompleteAuth(c.Context(), cfg.Adapter, &Params{ctx: c})
if err != nil {
return cfg.ErrorHandler(c, err)
}
Expand All @@ -145,6 +180,8 @@ func (CompleteAuthCompleteHandler) New(cfg Config) fiber.Handler {
}
expires := time.Now().Add(duration)

fmt.Println(expires)

session, err := cfg.Adapter.CreateSession(c.Context(), user.ID, expires)
if err != nil {
return cfg.ErrorHandler(c, err)
Expand Down Expand Up @@ -223,6 +260,9 @@ type Config struct {
// LogoutHandler is the handler to logout.
LogoutHandler GothHandler

// SessionHandler is the handler to manage the session.
SessionHandler GothHandler

// Response filter that is executed when responses need to returned.
ResponseFilter func(c *fiber.Ctx) error

Expand Down Expand Up @@ -258,6 +298,7 @@ var ConfigDefault = Config{
BeginAuthHandler: BeginAuthHandler{},
CompleteAuthHandler: CompleteAuthCompleteHandler{},
LogoutHandler: LogoutHandler{},
SessionHandler: SessionHandler{},
Encryptor: EncryptCookie,
Decryptor: DecryptCookie,
Expiry: "7h",
Expand Down Expand Up @@ -308,6 +349,10 @@ func configDefault(config ...Config) Config {
cfg.ErrorHandler = ConfigDefault.ErrorHandler
}

if cfg.SessionHandler == nil {
cfg.SessionHandler = ConfigDefault.SessionHandler
}

if cfg.Encryptor == nil {
cfg.Encryptor = ConfigDefault.Encryptor
}
Expand Down
12 changes: 5 additions & 7 deletions providers/github/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,20 +37,18 @@ type githubProvider struct {
providerType providers.ProviderType
client *http.Client
config *oauth2.Config
adapter adapters.Adapter

providers.UnimplementedProvider
}

// New creates a new GitHub provider.
func New(adapter adapters.Adapter, clientKey, secret, callbackURL string, scopes ...string) *githubProvider {
func New(clientKey, secret, callbackURL string, scopes ...string) *githubProvider {
p := &githubProvider{
id: "github",
name: "GitHub",
clientKey: clientKey,
secret: secret,
callbackURL: callbackURL,
adapter: adapter,
userURL: UserURL,
emailURL: EmailURL,
authURL: AuthURL,
Expand Down Expand Up @@ -91,7 +89,7 @@ func (a *authIntent) GetAuthURL() (string, error) {
}

// BeginAuth starts the authentication process.
func (g *githubProvider) BeginAuth(state string) (providers.AuthIntent, error) {
func (g *githubProvider) BeginAuth(ctx context.Context, adapter adapters.Adapter, state string) (providers.AuthIntent, error) {
verifier := oauth2.GenerateVerifier()
url := g.config.AuthCodeURL(state, oauth2.S256ChallengeOption(verifier))

Expand All @@ -101,7 +99,7 @@ func (g *githubProvider) BeginAuth(state string) (providers.AuthIntent, error) {
}

// CompleteAuth completes the authentication process.
func (g *githubProvider) CompleteAuth(ctx context.Context, params providers.AuthParams) (adapters.User, error) {
func (g *githubProvider) CompleteAuth(ctx context.Context, adapter adapters.Adapter, params providers.AuthParams) (adapters.User, error) {
u := struct {
ID int `json:"id"`
Email string `json:"email"`
Expand Down Expand Up @@ -157,12 +155,12 @@ func (g *githubProvider) CompleteAuth(ctx context.Context, params providers.Auth
},
}

user, err = g.adapter.CreateUser(ctx, user)
user, err = adapter.CreateUser(ctx, user)
if err != nil {
return adapters.User{}, err
}

user, err = g.adapter.GetUser(user.ID)
user, err = adapter.GetUser(ctx, user.ID)
if err != nil {
return adapters.User{}, err
}
Expand Down
8 changes: 4 additions & 4 deletions providers/providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ type Provider interface {
// Type returns the provider's type.
Type() ProviderType
// BeginAuth starts the authentication process.
BeginAuth(state string) (AuthIntent, error)
BeginAuth(ctx context.Context, adapter adapters.Adapter, state string) (AuthIntent, error)
// CompleteAuth completes the authentication process.
CompleteAuth(ctx context.Context, params AuthParams) (adapters.User, error)
CompleteAuth(ctx context.Context, adapter adapters.Adapter, params AuthParams) (adapters.User, error)
}

// AuthParams is the type of authentication parameters.
Expand Down Expand Up @@ -125,11 +125,11 @@ func (u *UnimplementedProvider) Debug(debug bool) {
}

// BeginAuth starts the authentication process.
func (u *UnimplementedProvider) BeginAuth(state string) (AuthIntent, error) {
func (u *UnimplementedProvider) BeginAuth(_ context.Context, _ adapters.Adapter, state string) (AuthIntent, error) {
return nil, ErrUnimplemented
}

// CompleteAuth completes the authentication process.
func (u *UnimplementedProvider) CompleteAuth(ctx context.Context, params AuthParams) (adapters.User, error) {
func (u *UnimplementedProvider) CompleteAuth(_ context.Context, _ adapters.Adapter, params AuthParams) (adapters.User, error) {
return adapters.User{}, ErrUnimplemented
}

0 comments on commit 8d5089c

Please sign in to comment.